/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>

#include "ruy/asm_helpers.h"
#include "ruy/check_macros.h"
#include "ruy/kernel_arm.h"
#include "ruy/opt_set.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"

namespace ruy {

#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)

#define RUY_ASM_LABEL_STORE_UINT8 91
#define RUY_ASM_LABEL_STORE_INT8 92
#define RUY_ASM_LABEL_STORE_INT16 93
#define RUY_ASM_LABEL_STORE_INT32 94
#define RUY_ASM_LABEL_AFTER_STORE 99

#define RUY_OFFSET_BIAS 0
#define RUY_OFFSET_LHS_SUMS 8
#define RUY_OFFSET_RHS_SUMS 16
#define RUY_OFFSET_LHS_BASE_PTR 24
#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 32
#define RUY_OFFSET_MULTIPLIER_EXPONENT 40
#define RUY_OFFSET_RHS_BASE_PTR 48
#define RUY_OFFSET_DST_BASE_PTR 56
#define RUY_OFFSET_LHS_ZERO_POINT 64
#define RUY_OFFSET_RHS_ZERO_POINT 68
#define RUY_OFFSET_DST_ZERO_POINT 72
#define RUY_OFFSET_PROD_ZP_DEPTH 76
#define RUY_OFFSET_START_ROW 80
#define RUY_OFFSET_START_COL 84
#define RUY_OFFSET_LAST_ROW 88
#define RUY_OFFSET_LAST_COL 92
#define RUY_OFFSET_DST_ROWS 96
#define RUY_OFFSET_DST_COLS 100
#define RUY_OFFSET_LHS_STRIDE 104
#define RUY_OFFSET_RHS_STRIDE 108
#define RUY_OFFSET_DST_STRIDE 112
#define RUY_OFFSET_DEPTH 116
#define RUY_OFFSET_CLAMP_MIN 120
#define RUY_OFFSET_CLAMP_MAX 124
#define RUY_OFFSET_FLAGS 128

template <typename Params>
void CheckOffsetsInKernelParams8bit(const Params&) {
  static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT,
                "");
  static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT,
                "");
  static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT,
                "");
  static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH,
                "");
  static_assert(offsetof(Params, multiplier_fixedpoint) ==
                    RUY_OFFSET_MULTIPLIER_FIXEDPOINT,
                "");
  static_assert(
      offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT,
      "");
  static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
  static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
  static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
  static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, "");
  static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, "");
  static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
  static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
  static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
  static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
  static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
  static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
  static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
  static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
  static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
}
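
// Note: the RUY_OFFSET_* constants above are hard-coded as literal load
// offsets into *params in the asm blocks below, so these static_asserts are
// what keep the asm in sync with the actual KernelParams8bit struct layout:
// if a field moves, the build fails here instead of the kernel silently
// reading the wrong field.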

// Fast-int8-trick kernel, similar to this production gemmlowp kernel:
// NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits
// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L2296
//
// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75,
// since these are 64-bit, out-of-order and without dotprod support.
void Kernel8bitNeon(const KernelParams8bit<4, 4>& params) {
  profiler::ScopeLabel label("Kernel (kNeon)");
  CheckOffsetsInKernelParams8bit(params);

  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
  const std::int8_t* lhs_ptr = lhs_col_ptr;
  const std::int8_t* rhs_ptr = rhs_col_ptr;
  void* dst_col_ptr = params.dst_base_ptr;
  void* dst_ptr = dst_col_ptr;
  int row = params.start_row;
  int col = params.start_col;

  // The asm kernel below has the following NEON register allocation:
  //
  // v16 -- v31 are int32 accumulators.
  // During accumulation, v0 -- v3 are used to load int8 data from LHS and
  // v4 -- v7 from RHS:
  //
  //                                    int8 RHS 16x4 block
  //                           /-----------------------------------------\
  //                           |v4.b[0]          ...          v7.b[0]    |
  //                           |  ...                           ...      |
  //                           |v4.b[15]         ...          v7.b[15]   |
  //                           \-----------------------------------------/
  //    int8 LHS 4x16 block
  //  /---------------------\  /-----------------------------------------\
  //  |v0.b[0] ... v0.b[15] |  |v16.4s           ...          v28.4s     |
  //  |v1.b[0] ... v1.b[15] |  |v17.4s           ...          v29.4s     |
  //  |v2.b[0] ... v2.b[15] |  |v18.4s           ...          v30.4s     |
  //  |v3.b[0] ... v3.b[15] |  |v19.4s           ...          v31.4s     |
  //  \---------------------/  \-----------------------------------------/
  //                                int32 accumulators 4x4 block
  //
  // No attempt has been made so far at implementing the RUY_OPT_MAX_STREAMING
  // optimization for this kernel.
  asm volatile(
#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"

        // clang-format off

        // Load some parameters into registers.
        "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
        "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
        "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
        "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
        "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
        "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
        "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
        "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"

        // Load the first 64 bytes of LHS and RHS data.
        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
        "ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
        "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
        "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"

        // Clear accumulators.
        RUY_MAKE_ZERO(v16)
        RUY_MAKE_ZERO(v17)
        RUY_MAKE_ZERO(v18)
        RUY_MAKE_ZERO(v19)
        RUY_MAKE_ZERO(v20)
        RUY_MAKE_ZERO(v21)
        RUY_MAKE_ZERO(v22)
        RUY_MAKE_ZERO(v23)
        RUY_MAKE_ZERO(v24)
        RUY_MAKE_ZERO(v25)
        RUY_MAKE_ZERO(v26)
        RUY_MAKE_ZERO(v27)
        RUY_MAKE_ZERO(v28)
        RUY_MAKE_ZERO(v29)
        RUY_MAKE_ZERO(v30)
        RUY_MAKE_ZERO(v31)

        // w1 is the number of levels of depth that we have already loaded
        // LHS and RHS data for. Corresponding to the initial ld1 instructions
        // above, this is currently 16.
        "mov w1, #16\n"

        // Perform the first few multiply-adds on the data that we have already
        // loaded.
        "smull v8.8h, v0.8b, v4.8b\n"
        "smull v9.8h, v1.8b, v4.8b\n"
        "smull v10.8h, v2.8b, v4.8b\n"
        "smull v11.8h, v3.8b, v4.8b\n"
        "smull v12.8h, v0.8b, v5.8b\n"
        "smull v13.8h, v1.8b, v5.8b\n"
        "smull v14.8h, v2.8b, v5.8b\n"
        "smull v15.8h, v3.8b, v5.8b\n"

        // Multiply-accumulate second-half, again into the same
        // 16bit local accumulator registers. This is where we
        // take advantage of having int8 instead of uint8 and therefore
        // being able to accumulate two products into int16.
        "smlal2 v8.8h, v0.16b, v4.16b\n"
        "smlal2 v9.8h, v1.16b, v4.16b\n"
        "smlal2 v10.8h, v2.16b, v4.16b\n"
        "smlal2 v11.8h, v3.16b, v4.16b\n"
        "smlal2 v12.8h, v0.16b, v5.16b\n"
        "smlal2 v13.8h, v1.16b, v5.16b\n"
        "smlal2 v14.8h, v2.16b, v5.16b\n"
        "smlal2 v15.8h, v3.16b, v5.16b\n"


        // Main loop of the whole GEMM, over rows and columns of the
        // destination matrix.
        "1:\n"

        // Reminder - w1 is how many levels of depth we have already loaded
        // data for, w12 is the total depth.
        "cmp w1, w12\n"
        "beq 79f\n"

        "2:\n"

        // Some multiplications and 16-bit accumulation were already done above,
        // so we start right away in the middle.
        "sadalp v16.4s, v8.8h\n"
        "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
        "smull v8.8h, v0.8b, v6.8b\n"
        "sadalp v17.4s, v9.8h\n"
        "ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
        "smull v9.8h, v1.8b, v6.8b\n"
        "sadalp v18.4s, v10.8h\n"
        "smull v10.8h, v2.8b, v6.8b\n"
        "sadalp v19.4s, v11.8h\n"
        "smull v11.8h, v3.8b, v6.8b\n"
        "sadalp v20.4s, v12.8h\n"
        "smull v12.8h, v0.8b, v7.8b\n"
        "sadalp v21.4s, v13.8h\n"
        "smull v13.8h, v1.8b, v7.8b\n"
        "sadalp v22.4s, v14.8h\n"
        "smull v14.8h, v2.8b, v7.8b\n"
        "sadalp v23.4s, v15.8h\n"
        "smull v15.8h, v3.8b, v7.8b\n"

        // Multiply-accumulate second-half, again into the same
        // 16bit local accumulator registers. This is where we
        // take advantage of having int8 instead of uint8 and therefore
        // being able to accumulate two products into int16.
        "smlal2 v8.8h, v0.16b, v6.16b\n"
        "smlal2 v9.8h, v1.16b, v6.16b\n"
        "smlal2 v10.8h, v2.16b, v6.16b\n"
        "smlal2 v11.8h, v3.16b, v6.16b\n"

        "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"

        "smlal2 v12.8h, v0.16b, v7.16b\n"
        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v13.8h, v1.16b, v7.16b\n"
        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v14.8h, v2.16b, v7.16b\n"
        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
        "smlal2 v15.8h, v3.16b, v7.16b\n"
        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"

        "sadalp v24.4s, v8.8h\n"
        "smull v8.8h, v0.8b, v4.8b\n"
        "sadalp v25.4s, v9.8h\n"
        "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
        "smull v9.8h, v1.8b, v4.8b\n"
        "sadalp v26.4s, v10.8h\n"
        "smull v10.8h, v2.8b, v4.8b\n"
        "sadalp v27.4s, v11.8h\n"
        "smull v11.8h, v3.8b, v4.8b\n"
        "sadalp v28.4s, v12.8h\n"
        "smull v12.8h, v0.8b, v5.8b\n"
        "sadalp v29.4s, v13.8h\n"
        "smull v13.8h, v1.8b, v5.8b\n"
        "sadalp v30.4s, v14.8h\n"
        "smull v14.8h, v2.8b, v5.8b\n"
        "sadalp v31.4s, v15.8h\n"
        "smull v15.8h, v3.8b, v5.8b\n"

        // Multiply-accumulate second-half, again into the same
        // 16bit local accumulator registers. This is where we
        // take advantage of having int8 instead of uint8 and therefore
        // being able to accumulate two products into int16.
        "smlal2 v8.8h, v0.16b, v4.16b\n"
        "smlal2 v9.8h, v1.16b, v4.16b\n"
        "smlal2 v10.8h, v2.16b, v4.16b\n"
        "smlal2 v11.8h, v3.16b, v4.16b\n"

        "smlal2 v12.8h, v0.16b, v5.16b\n"
        "smlal2 v13.8h, v1.16b, v5.16b\n"
        "smlal2 v14.8h, v2.16b, v5.16b\n"
        "smlal2 v15.8h, v3.16b, v5.16b\n"

        // Each iteration of this loop advances by 16 levels of depth.
        "add w1, w1, #16\n"

        // Loop termination condition.
        "cmp w1, w12\n"

        "blt 2b\n"

        "79:\n"

        "sadalp v16.4s, v8.8h\n"
        "smull v8.8h, v0.8b, v6.8b\n"
        "sadalp v17.4s, v9.8h\n"
        "smull v9.8h, v1.8b, v6.8b\n"
        "sadalp v18.4s, v10.8h\n"
        "smull v10.8h, v2.8b, v6.8b\n"
        "sadalp v19.4s, v11.8h\n"
        "smull v11.8h, v3.8b, v6.8b\n"
        "sadalp v20.4s, v12.8h\n"
        "smull v12.8h, v0.8b, v7.8b\n"
        "sadalp v21.4s, v13.8h\n"
        "smull v13.8h, v1.8b, v7.8b\n"
        "sadalp v22.4s, v14.8h\n"
        "smull v14.8h, v2.8b, v7.8b\n"
        "sadalp v23.4s, v15.8h\n"
        "smull v15.8h, v3.8b, v7.8b\n"

        // Multiply-accumulate second-half, again into the same
        // 16bit local accumulator registers. This is where we
        // take advantage of having int8 instead of uint8 and therefore
        // being able to accumulate two products into int16.
        "smlal2 v8.8h, v0.16b, v6.16b\n"
        "smlal2 v9.8h, v1.16b, v6.16b\n"
        "smlal2 v10.8h, v2.16b, v6.16b\n"
        "smlal2 v11.8h, v3.16b, v6.16b\n"

        "smlal2 v12.8h, v0.16b, v7.16b\n"
        "smlal2 v13.8h, v1.16b, v7.16b\n"
        "smlal2 v14.8h, v2.16b, v7.16b\n"
        "smlal2 v15.8h, v3.16b, v7.16b\n"

        "sadalp v24.4s, v8.8h\n"
        "sadalp v25.4s, v9.8h\n"
        "sadalp v26.4s, v10.8h\n"
        "sadalp v27.4s, v11.8h\n"
        "sadalp v28.4s, v12.8h\n"
        "sadalp v29.4s, v13.8h\n"
        "sadalp v30.4s, v14.8h\n"
        "sadalp v31.4s, v15.8h\n"

        // End of accumulation. The registers v16 -- v31 contain the final
        // int32 accumulator values of the current 4x4 destination block.
        // We now have to compute the final 8-bit values from these int32
        // accumulators, and advance to the next 4x4 block. We intertwine
        // these two aspects whenever possible for optimal pipelining, both
        // at the data flow level (prefetch data for next block as early as
        // possible) and instruction pipelining level (some of the next-block
        // work can dual-issue with some of the final work on the current
        // block).

        // Reduce 32bit accumulators horizontally.
        "addp v16.4s, v16.4s, v17.4s\n"
        "addp v18.4s, v18.4s, v19.4s\n"
        "addp v20.4s, v20.4s, v21.4s\n"
        "addp v22.4s, v22.4s, v23.4s\n"
        "addp v24.4s, v24.4s, v25.4s\n"
        "addp v26.4s, v26.4s, v27.4s\n"
        "addp v28.4s, v28.4s, v29.4s\n"
        "addp v30.4s, v30.4s, v31.4s\n"

        // Reduce 32bit accumulators horizontally, second pass
        // (each pass adds pairwise; we need to add 4-wise).
        "addp v16.4s, v16.4s, v18.4s\n"
        "addp v17.4s, v20.4s, v22.4s\n"
        "addp v18.4s, v24.4s, v26.4s\n"
        "addp v19.4s, v28.4s, v30.4s\n"

        // Logic to advance to the next block in preparation for the next
        // iteration of the main loop. For now, we only want to compute
        // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
        // not yet ready to update the values of row and col, as we still need
        // the current values for the rest of the work on the current block.

        "cmp %w[row], w7\n"  // Have we finished the last row?
        "bge 4f\n"           // If finished last row, go to 4.
        // Not finished last row: then advance to next row.
        "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n"
        "b 5f\n"
        "4:\n"  // Finished last row...
        "mov %[lhs_col_ptr], x5\n"  // Go back to first row.
        // Now we need to advance to the next column. If we already
        // finished the last column, then in principle we are done; however
        // we can't just return here, as we need to allow the end work of the
        // current block to complete. The good news is that at this point it
        // doesn't matter what data we load for the next column, since
        // we will exit from the main loop below before actually storing
        // anything computed from that data.
        "cmp %w[col], w8\n"  // Have we finished the last column?
        "bge 5f\n"  // If yes, just carry on without updating the column pointer.
        // Not finished last column: then advance to next column.
        "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n"
        "5:\n"

        // Set the LHS and RHS data pointers to the start of the columns just
        // computed.
        "mov %[lhs_ptr], %[lhs_col_ptr]\n"
        "mov %[rhs_ptr], %[rhs_col_ptr]\n"

        // Load some parameters needed for the end work on current block.
        "mvni v8.4s, #0\n"
404 "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
405 "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
406 "ins v13.h[4], w4\n" // dst_zero_point
407 "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
408 "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
409 "dup v9.4s, w3\n" // create prod_zp_depth_vec
410
411 // Now we load: bias data, LHS sums data, RHS sums data.
412
413 // First, load the base pointers from the params.
414 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
415
416 // Determine the channel index.
417 "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
418 "csel w3, %w[row], %w[col], eq\n"
419
420 // Offset the bias pointer as needed given the current row, col.
421 "add x5, x1, x3, lsl #2\n"
422
423 // If there is no bias, use no offset, just address the passed zero
424 // data.
425 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
426 "csel x1, x1, x5, eq\n"
427
428 // Load 4 bias values.
429 "ld1 {v14.4s}, [x1]\n"
430
431 // Load the multiplier_fixedpoint values.
432 "add x5, x4, x3, lsl #2\n"
433 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
434 "csel x4, x4, x5, eq\n"
435 "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
436
437 // Now that we know what LHS and RHS data the next iteration of the
438 // main loop will need to load, we start loading the first 32 bytes of
439 // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
440 // in the rest of the work on the current block.
441 "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
442 "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
443 "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
444 "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
445 "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
446 "ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
447 "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
448 "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
449
450 // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
451 // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
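        // For reference, with Z1 = lhs_zero_point, Z2 = rhs_zero_point and
        // N = depth, equation (7) expands to
        //   sum((lhs - Z1) * (rhs - Z2))
        //     = sum(lhs * rhs) - Z2 * sum(lhs) - Z1 * sum(rhs) + N * Z1 * Z2.
        // The accumulators hold sum(lhs * rhs); the N*Z1*Z2 term is folded
        // into the bias here, and the two sums terms are subtracted below.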
452 "add v14.4s, v14.4s, v9.4s\n"
453
454 // Perform the bias-addition (per the above, we have just folded into
455 // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
456 // Jump based on channel dimension.
457 "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
458 "bne 6f\n"
459 // Case where channels are rows
460 "add v16.4s, v16.4s, v14.4s\n"
461 "add v17.4s, v17.4s, v14.4s\n"
462 "add v18.4s, v18.4s, v14.4s\n"
463 "add v19.4s, v19.4s, v14.4s\n"
464 "b 7f\n"
465
466 "6:\n"
467 // Case where channels are columns
468 "dup v20.4s, v14.s[0]\n"
469 "dup v21.4s, v14.s[1]\n"
470 "dup v22.4s, v14.s[2]\n"
471 "dup v23.4s, v14.s[3]\n"
472 "add v16.4s, v16.4s, v20.4s\n"
473 "add v17.4s, v17.4s, v21.4s\n"
474 "add v18.4s, v18.4s, v22.4s\n"
475 "add v19.4s, v19.4s, v23.4s\n"
476 "7:\n"
477
478 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
479 "beq 401f\n"
480 "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
481 "add x3, x3, %x[col], lsl #2\n"
482 "ld1 {v14.4s}, [x3]\n"
483 "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
484 "dup v10.4s, w5\n" // create lhs_zero_point_vec
485 // Subtract rhs_sums * lhs_zero_point, per
486 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
487 "mls v16.4s, v10.4s, v14.s[0]\n"
488 "mls v17.4s, v10.4s, v14.s[1]\n"
489 "mls v18.4s, v10.4s, v14.s[2]\n"
490 "mls v19.4s, v10.4s, v14.s[3]\n"
491 "401:\n"
492
493 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
494 "beq 402f\n"
495 "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
496 "add x2, x2, %x[row], lsl #2\n"
497 "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
498 // Load 4 lhs_sums values.
499 "ld1 {v11.4s}, [x2]\n"
500 "ins v13.s[1], w5\n" // rhs_zero_point
501 // Compute lhs_sums * rhs_zero_point.
502 "mul v11.4s, v11.4s, v13.s[1]\n"
503 // Subtract lhs_sums * rhs_zero_point, per
504 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
505 "sub v16.4s, v16.4s, v11.4s\n"
506 "sub v17.4s, v17.4s, v11.4s\n"
507 "sub v18.4s, v18.4s, v11.4s\n"
508 "sub v19.4s, v19.4s, v11.4s\n"
509
510 // If the destination is int32, it means the user asks for the raw
511 // accumulators, no need for us to downquantize the value.
512 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
513 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
514
515 "402:\n"
516
517 // At this point we have computed the final int32 values. Now we
518 // start down-quantizing them to obtain the final 8bit values from them.
519
520 // As part of this down-quantization, our int32 values will be
521 // multiplied by a multiplier that has a fixed-point component and an
522 // exponent component.
523
524 //Load the exponent part of the multiplier.
525 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
526 // Determine the channel index.
527 "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
528 "csel w3, %w[row], %w[col], eq\n"
529
530 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
531 "add x5, x1, x3, lsl #2\n"
532 "csel x1, x1, x5, eq\n"
533
534 "ld1 {v14.4s}, [x1]\n"
535
536 "smin v11.4s, v8.4s, v14.4s\n"
537 "sub v12.4s, v14.4s, v11.4s\n"

        // Jump based on channel dimension.
        "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
        "bne 8f\n"
        // Case where channels are rows

        // Apply the positive exponent part of the multiplier.
        "sshl v16.4s, v16.4s, v12.4s\n"
        "sshl v17.4s, v17.4s, v12.4s\n"
        "sshl v18.4s, v18.4s, v12.4s\n"
        "sshl v19.4s, v19.4s, v12.4s\n"

        // Apply the fixed-point part of the multiplier.
        "sqdmulh v16.4s, v16.4s, v15.4s\n"
        "sqdmulh v17.4s, v17.4s, v15.4s\n"
        "sqdmulh v18.4s, v18.4s, v15.4s\n"
        "sqdmulh v19.4s, v19.4s, v15.4s\n"

        // Apply the negative exponent part of the multiplier.
        "srshl v16.4s, v16.4s, v11.4s\n"
        "srshl v17.4s, v17.4s, v11.4s\n"
        "srshl v18.4s, v18.4s, v11.4s\n"
        "srshl v19.4s, v19.4s, v11.4s\n"
        "b 9f\n"

        "8:\n"
        // Case where channels are columns

        // Apply the positive exponent part of the multiplier.
        "dup v20.4s, v12.s[0]\n"
        "dup v21.4s, v12.s[1]\n"
        "dup v22.4s, v12.s[2]\n"
        "dup v23.4s, v12.s[3]\n"
        "sshl v16.4s, v16.4s, v20.4s\n"
        "sshl v17.4s, v17.4s, v21.4s\n"
        "sshl v18.4s, v18.4s, v22.4s\n"
        "sshl v19.4s, v19.4s, v23.4s\n"

        // Apply the fixed-point part of the multiplier.
        "sqdmulh v16.4s, v16.4s, v15.s[0]\n"
        "sqdmulh v17.4s, v17.4s, v15.s[1]\n"
        "sqdmulh v18.4s, v18.4s, v15.s[2]\n"
        "sqdmulh v19.4s, v19.4s, v15.s[3]\n"

        // Apply the negative exponent part of the multiplier.
        "dup v20.4s, v11.s[0]\n"
        "dup v21.4s, v11.s[1]\n"
        "dup v22.4s, v11.s[2]\n"
        "dup v23.4s, v11.s[3]\n"
        "srshl v16.4s, v16.4s, v20.4s\n"
        "srshl v17.4s, v17.4s, v21.4s\n"
        "srshl v18.4s, v18.4s, v22.4s\n"
        "srshl v19.4s, v19.4s, v23.4s\n"
        "9:\n"

        "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
        "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
        "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
        "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"

        RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"

        // Cast-and-saturate from int32 to int16
        "sqxtn v16.4h, v16.4s\n"
        "sqxtn2 v16.8h, v17.4s\n"
        "sqxtn v17.4h, v18.4s\n"
        "sqxtn2 v17.8h, v19.4s\n"

        // At this point, v18 -- v31 aren't used anymore for the current block,
        // so we can start clearing these accumulators for the next block
        // (next iteration of the main loop).
        RUY_MAKE_ZERO(v18)
        RUY_MAKE_ZERO(v19)
        RUY_MAKE_ZERO(v20)
        RUY_MAKE_ZERO(v21)
        RUY_MAKE_ZERO(v22)
        RUY_MAKE_ZERO(v23)
        RUY_MAKE_ZERO(v24)
        RUY_MAKE_ZERO(v25)
        RUY_MAKE_ZERO(v26)
        RUY_MAKE_ZERO(v27)
        RUY_MAKE_ZERO(v28)
        RUY_MAKE_ZERO(v29)
        RUY_MAKE_ZERO(v30)
        RUY_MAKE_ZERO(v31)

        // Add the destination zero point
        "dup v14.8h, v13.h[4]\n"
        "add v16.8h, v16.8h, v14.8h\n"
        "add v17.8h, v17.8h, v14.8h\n"

        // Cast-and-saturate from int16 to uint8
        "sqxtun v16.8b, v16.8h\n"
        "sqxtun2 v16.16b, v17.8h\n"

        // Load the clamp_min, clamp_max bounds
        "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
        "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
        "dup v14.16b, w2\n"  // clamp_min
        "dup v15.16b, w3\n"  // clamp_max

        // Apply the clamp_min bound
        "umax v16.16b, v16.16b, v14.16b\n"
        // Apply the clamp_max bound
        "umin v16.16b, v16.16b, v15.16b\n"

        // Compute how much of the 4x4 block of destination 8-bit values that
        // we have computed fits in the destination matrix. Typically, all of
        // it fits, but when the destination matrix shape is not a multiple
        // of 4x4, there are some 4x4 blocks along the boundaries that do
        // not fit entirely.
        "sub w1, %w[dst_rows], %w[row]\n"
        "sub w2, %w[dst_cols], %w[col]\n"
        "mov w3, #4\n"
        "cmp w1, #4\n"
        // Compute w1 = how many rows of the 4x4 block fit
        "csel w1, w1, w3, le\n"
        "cmp w2, #4\n"
        // Compute w2 = how many cols of the 4x4 block fit
        "csel w2, w2, w3, le\n"

        // Test if w1==4 && w2==4, i.e. if all of the 4x4 block fits.
        "cmp w1, w3\n"
        "ccmp w2, w3, 0, eq\n"
662 "mov x4, %[dst_ptr]\n"
663 // Yes, all of the 4x4 block fits, go to fast path.
664 "beq 30f\n"
665 // Not all of the 4x4 block fits.
666 // Store to dst_tmp_buf
667 "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
668 // Slow loop copying from dst_tmp_buf to dst.
669 "mov x3, %[dst_tmp_buf]\n"
670 "mov w6, #0\n"
671 "50:\n"
672 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
673 "mov w5, #0\n"
674 "51:\n"
675 "ldrb w7, [x3, w5, uxtw]\n"
676 "strb w7, [x4, w5, uxtw]\n"
677 "add w5, w5, #1\n"
678 "cmp w5, w1\n"
679 "blt 51b\n"
680 "add w6, w6, #1\n"
681 "add x3, x3, #4\n"
682 "add x4, x4, x11\n"
683 "cmp w6, w2\n"
684 "blt 50b\n"
685 "b 31f\n"
686 "30:\n"
687 // Yes, all of the 4x4 block fits.
688 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
689 "mov x3, x4\n"
690 "st1 {v16.b}[0], [x3], #1\n"
691 "add x4, x4, x11\n"
692 "st1 {v16.b}[1], [x3], #1\n"
693 "st1 {v16.b}[2], [x3], #1\n"
694 "st1 {v16.b}[3], [x3], #1\n"
695 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
696 "mov x3, x4\n"
697 "st1 {v16.b}[4], [x3], #1\n"
698 "add x4, x4, x11\n"
699 "st1 {v16.b}[5], [x3], #1\n"
700 "st1 {v16.b}[6], [x3], #1\n"
701 "st1 {v16.b}[7], [x3], #1\n"
702 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
703 "mov x3, x4\n"
704 "st1 {v16.b}[8], [x3], #1\n"
705 "add x4, x4, x11\n"
706 "st1 {v16.b}[9], [x3], #1\n"
707 "st1 {v16.b}[10], [x3], #1\n"
708 "st1 {v16.b}[11], [x3], #1\n"
709 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
710 "mov x3, x4\n"
711 "st1 {v16.b}[12], [x3], #1\n"
712 "add x4, x4, x11\n"
713 "st1 {v16.b}[13], [x3], #1\n"
714 "st1 {v16.b}[14], [x3], #1\n"
715 "st1 {v16.b}[15], [x3], #1\n"
716 "31:\n"
717
718 "add %[dst_ptr], %[dst_ptr], #4\n"
719
720 RUY_MAKE_ZERO(v16)
721 RUY_MAKE_ZERO(v17)
722
723 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
724
725 RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
726
727 // Cast-and-saturate from int32 to int16
728 "sqxtn v16.4h, v16.4s\n"
729 "sqxtn2 v16.8h, v17.4s\n"
730 "sqxtn v17.4h, v18.4s\n"
731 "sqxtn2 v17.8h, v19.4s\n"
732
733 // At this point, v18 -- v31 aren't used anymore for the current block,
734 // so we can start clearing these accumulators for the next block
735 // (next iteration of the main loop).
736 RUY_MAKE_ZERO(v18)
737 RUY_MAKE_ZERO(v19)
738 RUY_MAKE_ZERO(v20)
739 RUY_MAKE_ZERO(v21)
740 RUY_MAKE_ZERO(v22)
741 RUY_MAKE_ZERO(v23)
742 RUY_MAKE_ZERO(v24)
743 RUY_MAKE_ZERO(v25)
744 RUY_MAKE_ZERO(v26)
745 RUY_MAKE_ZERO(v27)
746 RUY_MAKE_ZERO(v28)
747 RUY_MAKE_ZERO(v29)
748 RUY_MAKE_ZERO(v30)
749 RUY_MAKE_ZERO(v31)
750
751 // Add the destination zero point
752 "dup v14.8h, v13.h[4]\n"
753 "add v16.8h, v16.8h, v14.8h\n"
754 "add v17.8h, v17.8h, v14.8h\n"
755
756 // Cast-and-saturate from int16 to int8
757 "sqxtn v16.8b, v16.8h\n"
758 "sqxtn2 v16.16b, v17.8h\n"
759
760 // Load the clamp_min, clamp_max bounds
761 "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
762 "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
763 "dup v14.16b, w2\n" // clamp_min
764 "dup v15.16b, w3\n" // clamp_max
765
766 // Apply the clamp_min bound
767 "smax v16.16b, v16.16b, v14.16b\n"
768 // Apply the clamp_max bound
769 "smin v16.16b, v16.16b, v15.16b\n"
770
771 // Compute how much of the 4x4 block of destination 8bit values that
772 // we have computed, fit in the destination matrix. Typically, all of
773 // it fits, but when the destination matrix shape is not a multiple
774 // of 4x4, there are some 4x4 blocks along the boundaries that do
775 // not fit entirely.
776 "sub w1, %w[dst_rows], %w[row]\n"
777 "sub w2, %w[dst_cols], %w[col]\n"
778 "mov w3, #4\n"
779 "cmp w1, #4\n"
780 // Compute w1 = how many rows of the 4x4 block fit
781 "csel w1, w1, w3, le\n"
782 "cmp w2, #4\n"
783 // Compute w2 = how many cols of the 4x4 block fit
784 "csel w2, w2, w3, le\n"
785
786 // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
787 "cmp w1, w3\n"
788 "ccmp w2, w3, 0, eq\n"
789 "mov x4, %[dst_ptr]\n"
790 // Yes, all of the 4x4 block fits, go to fast path.
791 "beq 30f\n"
792 // Not all of the 4x4 block fits.
793 // Store to dst_tmp_buf
794 "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
795 // Slow loop copying from dst_tmp_buf to dst.
796 "mov x3, %[dst_tmp_buf]\n"
797 "mov w6, #0\n"
798 "50:\n"
799 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
800 "mov w5, #0\n"
801 "51:\n"
802 "ldrb w7, [x3, w5, uxtw]\n"
803 "strb w7, [x4, w5, uxtw]\n"
804 "add w5, w5, #1\n"
805 "cmp w5, w1\n"
806 "blt 51b\n"
807 "add w6, w6, #1\n"
808 "add x3, x3, #4\n"
809 "add x4, x4, x11\n"
810 "cmp w6, w2\n"
811 "blt 50b\n"
812 "b 31f\n"
813 "30:\n"
814 // Yes, all of the 4x4 block fits.
815 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
816 "mov x3, x4\n"
817 "st1 {v16.b}[0], [x3], #1\n"
818 "add x4, x4, x11\n"
819 "st1 {v16.b}[1], [x3], #1\n"
820 "st1 {v16.b}[2], [x3], #1\n"
821 "st1 {v16.b}[3], [x3], #1\n"
822 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
823 "mov x3, x4\n"
824 "st1 {v16.b}[4], [x3], #1\n"
825 "add x4, x4, x11\n"
826 "st1 {v16.b}[5], [x3], #1\n"
827 "st1 {v16.b}[6], [x3], #1\n"
828 "st1 {v16.b}[7], [x3], #1\n"
829 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
830 "mov x3, x4\n"
831 "st1 {v16.b}[8], [x3], #1\n"
832 "add x4, x4, x11\n"
833 "st1 {v16.b}[9], [x3], #1\n"
834 "st1 {v16.b}[10], [x3], #1\n"
835 "st1 {v16.b}[11], [x3], #1\n"
836 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
837 "mov x3, x4\n"
838 "st1 {v16.b}[12], [x3], #1\n"
839 "add x4, x4, x11\n"
840 "st1 {v16.b}[13], [x3], #1\n"
841 "st1 {v16.b}[14], [x3], #1\n"
842 "st1 {v16.b}[15], [x3], #1\n"
843 "31:\n"
844
845 "add %[dst_ptr], %[dst_ptr], #4\n"
846
847 RUY_MAKE_ZERO(v16)
848 RUY_MAKE_ZERO(v17)
849
850 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"

        RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"

        // Add the destination zero point
        "dup v14.4h, v13.h[4]\n"
        "saddw v16.4s, v16.4s, v14.4h\n"
        "saddw v17.4s, v17.4s, v14.4h\n"
        "saddw v18.4s, v18.4s, v14.4h\n"
        "saddw v19.4s, v19.4s, v14.4h\n"

        // Cast-and-saturate from int32 to int16
        "sqxtn v16.4h, v16.4s\n"
        "sqxtn2 v16.8h, v17.4s\n"
        "sqxtn v17.4h, v18.4s\n"
        "sqxtn2 v17.8h, v19.4s\n"

        // At this point, v18 -- v31 aren't used anymore for the current block,
        // so we can start clearing these accumulators for the next block
        // (next iteration of the main loop).
        RUY_MAKE_ZERO(v18)
        RUY_MAKE_ZERO(v19)
        RUY_MAKE_ZERO(v20)
        RUY_MAKE_ZERO(v21)
        RUY_MAKE_ZERO(v22)
        RUY_MAKE_ZERO(v23)
        RUY_MAKE_ZERO(v24)
        RUY_MAKE_ZERO(v25)
        RUY_MAKE_ZERO(v26)
        RUY_MAKE_ZERO(v27)
        RUY_MAKE_ZERO(v28)
        RUY_MAKE_ZERO(v29)
        RUY_MAKE_ZERO(v30)
        RUY_MAKE_ZERO(v31)

        // Load the clamp_min, clamp_max bounds
        "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
        "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
        "dup v14.8h, w2\n"  // clamp_min
        "dup v15.8h, w3\n"  // clamp_max

        // Apply the clamp_min bound
        "smax v16.8h, v16.8h, v14.8h\n"
        "smax v17.8h, v17.8h, v14.8h\n"
        // Apply the clamp_max bound
        "smin v16.8h, v16.8h, v15.8h\n"
        "smin v17.8h, v17.8h, v15.8h\n"

        // Compute how much of the 4x4 block of destination 16-bit values that
        // we have computed fits in the destination matrix. Typically, all of
        // it fits, but when the destination matrix shape is not a multiple
        // of 4x4, there are some 4x4 blocks along the boundaries that do
        // not fit entirely.
        "sub w1, %w[dst_rows], %w[row]\n"
        "sub w2, %w[dst_cols], %w[col]\n"
        "mov w3, #4\n"
        "cmp w1, #4\n"
        // Compute w1 = how many rows of the 4x4 block fit
        "csel w1, w1, w3, le\n"
        "cmp w2, #4\n"
        // Compute w2 = how many cols of the 4x4 block fit
        "csel w2, w2, w3, le\n"

        // Test if w1==4 && w2==4, i.e. if all of the 4x4 block fits.
        "cmp w1, w3\n"
        "ccmp w2, w3, 0, eq\n"
        "mov x4, %[dst_ptr]\n"
        // Yes, all of the 4x4 block fits, go to fast path.
        "beq 30f\n"
        // Not all of the 4x4 block fits.
        // Store to dst_tmp_buf
        "str q16, [%[dst_tmp_buf], #0]\n"
        "str q17, [%[dst_tmp_buf], #16]\n"
        // Slow loop copying from dst_tmp_buf to dst.
        "mov x3, %[dst_tmp_buf]\n"
        "mov w6, #0\n"
        "50:\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
        "mov w5, #0\n"
        "51:\n"
        "ldrh w7, [x3, x5, lsl #1]\n"
        "strh w7, [x4, x5, lsl #1]\n"
        "add w5, w5, #1\n"
        "cmp w5, w1\n"
        "blt 51b\n"
        "add w6, w6, #1\n"
        "add x3, x3, #8\n"
        "add x4, x4, x11\n"
        "cmp w6, w2\n"
        "blt 50b\n"
        "b 31f\n"
        "30:\n"
        // Yes, all of the 4x4 block fits.
        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
        "mov x3, x4\n"
        "st1 {v16.h}[0], [x3], #2\n"
        "add x4, x4, x11\n"
        "st1 {v16.h}[1], [x3], #2\n"
        "st1 {v16.h}[2], [x3], #2\n"
        "st1 {v16.h}[3], [x3], #2\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
        "mov x3, x4\n"
        "st1 {v16.h}[4], [x3], #2\n"
        "add x4, x4, x11\n"
        "st1 {v16.h}[5], [x3], #2\n"
        "st1 {v16.h}[6], [x3], #2\n"
        "st1 {v16.h}[7], [x3], #2\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
        "mov x3, x4\n"
        "st1 {v17.h}[0], [x3], #2\n"
        "add x4, x4, x11\n"
        "st1 {v17.h}[1], [x3], #2\n"
        "st1 {v17.h}[2], [x3], #2\n"
        "st1 {v17.h}[3], [x3], #2\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
        "mov x3, x4\n"
        "st1 {v17.h}[4], [x3], #2\n"
        "add x4, x4, x11\n"
        "st1 {v17.h}[5], [x3], #2\n"
        "st1 {v17.h}[6], [x3], #2\n"
        "st1 {v17.h}[7], [x3], #2\n"
        "31:\n"

        "add %[dst_ptr], %[dst_ptr], #8\n"

        RUY_MAKE_ZERO(v16)
        RUY_MAKE_ZERO(v17)

        "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"

        RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"

        // Since the store type is the same as the accum type, no need for
        // downcast. There's also no need for clamp by min/max.

        // At this point, v20 -- v31 aren't used anymore for the current block,
        // so we can start clearing these accumulators for the next block
        // (next iteration of the main loop).
        RUY_MAKE_ZERO(v20)
        RUY_MAKE_ZERO(v21)
        RUY_MAKE_ZERO(v22)
        RUY_MAKE_ZERO(v23)
        RUY_MAKE_ZERO(v24)
        RUY_MAKE_ZERO(v25)
        RUY_MAKE_ZERO(v26)
        RUY_MAKE_ZERO(v27)
        RUY_MAKE_ZERO(v28)
        RUY_MAKE_ZERO(v29)
        RUY_MAKE_ZERO(v30)
        RUY_MAKE_ZERO(v31)

        // Compute how much of the 4x4 block of destination int32 values that
        // we have computed fits in the destination matrix. Typically, all of
        // it fits, but when the destination matrix shape is not a multiple
        // of 4x4, there are some 4x4 blocks along the boundaries that do
        // not fit entirely.
        "sub w1, %w[dst_rows], %w[row]\n"
        "sub w2, %w[dst_cols], %w[col]\n"
        "mov w3, #4\n"
        "cmp w1, #4\n"
        // Compute w1 = how many rows of the 4x4 block fit
        "csel w1, w1, w3, le\n"
        "cmp w2, #4\n"
        // Compute w2 = how many cols of the 4x4 block fit
        "csel w2, w2, w3, le\n"

        // Test if w1==4 && w2==4, i.e. if all of the 4x4 block fits.
        "cmp w1, w3\n"
        "ccmp w2, w3, 0, eq\n"
        "mov x4, %[dst_ptr]\n"
        // Yes, all of the 4x4 block fits, go to fast path.
        "beq 30f\n"
        // Not all of the 4x4 block fits.
        // Store to dst_tmp_buf
        "str q16, [%[dst_tmp_buf], #0]\n"
        "str q17, [%[dst_tmp_buf], #16]\n"
        "str q18, [%[dst_tmp_buf], #32]\n"
        "str q19, [%[dst_tmp_buf], #48]\n"
        // Slow loop copying from dst_tmp_buf to dst.
        "mov x3, %[dst_tmp_buf]\n"
        "mov w6, #0\n"
        "50:\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
        "mov w5, #0\n"
        "51:\n"
        "ldr w7, [x3, x5, lsl #2]\n"
        "str w7, [x4, x5, lsl #2]\n"
        "add w5, w5, #1\n"
        "cmp w5, w1\n"
        "blt 51b\n"
        "add w6, w6, #1\n"
        "add x3, x3, #16\n"
        "add x4, x4, x11\n"
        "cmp w6, w2\n"
        "blt 50b\n"
        "b 31f\n"
        "30:\n"
        // Yes, all of the 4x4 block fits.
        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
        "mov x3, x4\n"
        "st1 {v16.s}[0], [x3], #4\n"
        "add x4, x4, x11\n"
        "st1 {v16.s}[1], [x3], #4\n"
        "st1 {v16.s}[2], [x3], #4\n"
        "st1 {v16.s}[3], [x3], #4\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
        "mov x3, x4\n"
        "st1 {v17.s}[0], [x3], #4\n"
        "add x4, x4, x11\n"
        "st1 {v17.s}[1], [x3], #4\n"
        "st1 {v17.s}[2], [x3], #4\n"
        "st1 {v17.s}[3], [x3], #4\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
        "mov x3, x4\n"
        "st1 {v18.s}[0], [x3], #4\n"
        "add x4, x4, x11\n"
        "st1 {v18.s}[1], [x3], #4\n"
        "st1 {v18.s}[2], [x3], #4\n"
        "st1 {v18.s}[3], [x3], #4\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
        "mov x3, x4\n"
        "st1 {v19.s}[0], [x3], #4\n"
        "add x4, x4, x11\n"
        "st1 {v19.s}[1], [x3], #4\n"
        "st1 {v19.s}[2], [x3], #4\n"
        "st1 {v19.s}[3], [x3], #4\n"
        "31:\n"

        "add %[dst_ptr], %[dst_ptr], #16\n"

        RUY_MAKE_ZERO(v16)
        RUY_MAKE_ZERO(v17)
        RUY_MAKE_ZERO(v18)
        RUY_MAKE_ZERO(v19)

        RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"

        // For the next block: perform the first few multiply-adds on the data
        // that we have already loaded.
        "smull v8.8h, v0.8b, v4.8b\n"
        "smull v9.8h, v1.8b, v4.8b\n"
        "smull v10.8h, v2.8b, v4.8b\n"
        "smull v11.8h, v3.8b, v4.8b\n"
        "smull v12.8h, v0.8b, v5.8b\n"
        "smull v13.8h, v1.8b, v5.8b\n"
        "smull v14.8h, v2.8b, v5.8b\n"
        "smull v15.8h, v3.8b, v5.8b\n"
        "smlal2 v8.8h, v0.16b, v4.16b\n"
        "smlal2 v9.8h, v1.16b, v4.16b\n"
        "smlal2 v10.8h, v2.16b, v4.16b\n"
        "smlal2 v11.8h, v3.16b, v4.16b\n"
        "smlal2 v12.8h, v0.16b, v5.16b\n"
        "smlal2 v13.8h, v1.16b, v5.16b\n"
        "smlal2 v14.8h, v2.16b, v5.16b\n"
        "smlal2 v15.8h, v3.16b, v5.16b\n"

        // Reload some params --- we had used x5 -- x7 for a few other things
        // since the last time we had loaded them.
        "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
        "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
        "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"

        // Move to the next block of the destination matrix, for the next iter
        // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
        // been updated earlier.
        // Have we reached the end row?
        "cmp %w[row], w7\n"
        "beq 20f\n"  // yes, end row.
        // Not end row. Move to the next row.
        "add %w[row], %w[row], #4\n"
        "b 21f\n"
        "20:\n"
        // Was already at end row.
        "mov %w[row], w6\n"  // Move back to first row.
        "add %w[col], %w[col], #4\n"  // Move to the next column.
        "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
        "mov %[dst_ptr], %[dst_col_ptr]\n"
        "21:\n"

        // Main loop exit condition: have we hit the end column?
        "cmp %w[col], w8\n"

        // w1 is the number of levels of depth that we have already loaded
        // LHS and RHS data for. Corresponding to the initial ld1 instructions
        // above, this is currently 16.
        "mov w1, #16\n"

        "ble 1b\n"

        // clang-format on

        : [lhs_col_ptr] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
          [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
          [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr),
          [row] "+r"(row), [col] "+r"(col)
        : [params] "r"(&params), [dst_rows] "r"(params.dst_rows),
          [dst_cols] "r"(params.dst_cols),
          [dst_tmp_buf] "r"(params.dst_tmp_buf),
          [dst_type_id] "r"(params.dst_type_id)
        : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11",
          "x12", "x13", "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5",
          "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
          "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
          "v25", "v26", "v27", "v28", "v29", "v30", "v31");
}
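
// For readers cross-checking the asm above: a rough scalar sketch of the
// accumulation performed for one 4x4 destination block (not part of the
// build; the packed_lhs/packed_rhs indexing is illustrative, and depth is
// assumed padded to a multiple of 16, which the packing code is expected to
// guarantee for this kernel):
//
//   std::int32_t acc[4][4] = {};
//   for (int d = 0; d < depth; ++d) {
//     for (int r = 0; r < 4; ++r) {
//       for (int c = 0; c < 4; ++c) {
//         acc[r][c] += std::int32_t(packed_lhs[r][d]) *
//                      std::int32_t(packed_rhs[d][c]);
//       }
//     }
//   }
//
// The bias, zero-point, multiplier and clamp steps above then turn acc[][]
// into the values actually stored.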

// Similar to existing Kernel8bitNeon but specialized for the case of
// RHS cols == 1.
// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75,
// since these are 64-bit, out-of-order and without dotprod support.
void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params) {
  profiler::ScopeLabel label("Kernel (kNeon)");

  CheckOffsetsInKernelParams8bit(params);

  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
  const std::int8_t* lhs_ptr = lhs_col_ptr;
  const std::int8_t* rhs_ptr = rhs_col_ptr;
  void* dst_col_ptr = params.dst_base_ptr;
  void* dst_ptr = dst_col_ptr;
  int row = params.start_row;
  int col = params.start_col;

  RUY_DCHECK(!(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL));

  // The asm kernel below has the following NEON register allocation:
  //
  // v16 -- v19 are int32 accumulators.
  // During accumulation, v0 -- v3 are used to load int8 data from LHS and
  // v4 from RHS:
  //
  //                        int8 RHS 16x1 block
  //                           /-----------\
  //                           |v4.b[0]    |
  //                           |  ...      |
  //                           |v4.b[15]   |
  //                           \-----------/
  //    int8 LHS 4x16 block
  //  /---------------------\  /-----------\
  //  |v0.b[0] ... v0.b[15] |  |v16.4s     |
  //  |v1.b[0] ... v1.b[15] |  |v17.4s     |
  //  |v2.b[0] ... v2.b[15] |  |v18.4s     |
  //  |v3.b[0] ... v3.b[15] |  |v19.4s     |
  //  \---------------------/  \-----------/
  //                      int32 accumulators 4x1 block
  //
  // No attempt has been made so far at implementing the RUY_OPT_MAX_STREAMING
  // optimization for this kernel.
  asm volatile(
#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"

        // clang-format off

        // Load some parameters into registers.
        "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
        "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
        "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
        "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
        "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
        "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
        "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
        "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"

        // Load the first 64 bytes of LHS data and the first 16 bytes of RHS
        // data.
        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
        "add %[rhs_ptr], %[rhs_ptr], #48\n"

        // Clear accumulators.
        RUY_MAKE_ZERO(v16)
        RUY_MAKE_ZERO(v17)
        RUY_MAKE_ZERO(v18)
        RUY_MAKE_ZERO(v19)

        // w1 is the number of levels of depth that we have already loaded
        // LHS and RHS data for. Corresponding to the initial ld1 instructions
        // above, this is currently 16.
        "mov w1, #16\n"

        // Perform the first few multiply-adds on the data that we have already
        // loaded.
        "smull v8.8h, v0.8b, v4.8b\n"
        "smull v9.8h, v1.8b, v4.8b\n"
        "smull v10.8h, v2.8b, v4.8b\n"
        "smull v11.8h, v3.8b, v4.8b\n"

        // Multiply-accumulate second-half, again into the same
        // 16bit local accumulator registers. This is where we
        // take advantage of having int8 instead of uint8 and therefore
        // being able to accumulate two products into int16.
        "smlal2 v8.8h, v0.16b, v4.16b\n"
        "smlal2 v9.8h, v1.16b, v4.16b\n"
        "smlal2 v10.8h, v2.16b, v4.16b\n"
        "smlal2 v11.8h, v3.16b, v4.16b\n"

        // Main loop of the whole GEMM, over rows and columns of the
        // destination matrix.
        "1:\n"

        // Reminder - w1 is how many levels of depth we have already loaded
        // data for, w12 is the total depth.
        "cmp w1, w12\n"
        "beq 79f\n"

        "2:\n"

        // Some multiplications and 16-bit accumulation were already done above,
        // so we start right away in the middle.
        "sadalp v16.4s, v8.8h\n"
        "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
        "add %[rhs_ptr], %[rhs_ptr], #48\n"
        "sadalp v17.4s, v9.8h\n"
        "sadalp v18.4s, v10.8h\n"
        "sadalp v19.4s, v11.8h\n"

        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"

        "smull v8.8h, v0.8b, v4.8b\n"
        "smull v9.8h, v1.8b, v4.8b\n"
        "smull v10.8h, v2.8b, v4.8b\n"
        "smull v11.8h, v3.8b, v4.8b\n"

        // Multiply-accumulate second-half, again into the same
        // 16bit local accumulator registers. This is where we
        // take advantage of having int8 instead of uint8 and therefore
        // being able to accumulate two products into int16.
        "smlal2 v8.8h, v0.16b, v4.16b\n"
        "smlal2 v9.8h, v1.16b, v4.16b\n"
        "smlal2 v10.8h, v2.16b, v4.16b\n"
        "smlal2 v11.8h, v3.16b, v4.16b\n"

        // Each iteration of this loop advances by 16 levels of depth.
        "add w1, w1, #16\n"

        // Loop termination condition.
        "cmp w1, w12\n"

        "blt 2b\n"

        "79:\n"

        "sadalp v16.4s, v8.8h\n"
        "sadalp v17.4s, v9.8h\n"
        "sadalp v18.4s, v10.8h\n"
        "sadalp v19.4s, v11.8h\n"

        // End of accumulation. The registers v16 -- v19 contain the final
        // int32 accumulator values of the current 4x1 destination block.
        // We now have to compute the final 8-bit values from these int32
        // accumulators, and advance to the next 4x1 block. We intertwine
        // these two aspects whenever possible for optimal pipelining, both
        // at the data flow level (prefetch data for next block as early as
        // possible) and instruction pipelining level (some of the next-block
        // work can dual-issue with some of the final work on the current
        // block).

        // Reduce 32bit accumulators horizontally.
        "addp v16.4s, v16.4s, v17.4s\n"
        "addp v18.4s, v18.4s, v19.4s\n"

        // Reduce 32bit accumulators horizontally, second pass
        // (each pass adds pairwise; we need to add 4-wise).
        "addp v16.4s, v16.4s, v18.4s\n"

        // Logic to advance to the next block in preparation for the next
        // iteration of the main loop. For now, we only want to compute
        // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
        // not yet ready to update the values of row and col, as we still need
        // the current values for the rest of the work on the current block.

        "cmp %w[row], w7\n"  // Have we finished the last row?
        "bge 4f\n"           // If finished last row, go to 4.
        // Not finished last row: then advance to next row.
        "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n"
        "b 5f\n"
        "4:\n"  // Finished last row...
        "mov %[lhs_col_ptr], x5\n"  // Go back to first row.
        // Now we need to advance to the next column. If we already
        // finished the last column, then in principle we are done; however
        // we can't just return here, as we need to allow the end work of the
        // current block to complete. The good news is that at this point it
        // doesn't matter what data we load for the next column, since
        // we will exit from the main loop below before actually storing
        // anything computed from that data.
        "cmp %w[col], w8\n"  // Have we finished the last column?
        "bge 5f\n"  // If yes, just carry on without updating the column pointer.
        // Not finished last column: then advance to next column.
        // (still multiply column stride by 4 due to packing)
        "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n"
        "5:\n"

        // Set the LHS and RHS data pointers to the start of the columns just
        // computed.
        "mov %[lhs_ptr], %[lhs_col_ptr]\n"
        "mov %[rhs_ptr], %[rhs_col_ptr]\n"

        // Load some parameters needed for the end work on current block.
        "mvni v8.4s, #0\n"
        "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
        "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
        "ins v13.h[4], w4\n"  // dst_zero_point
        "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
        "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
        "dup v9.4s, w3\n"  // create prod_zp_depth_vec
        "add x5, x4, %x[row], lsl #2\n"
        "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
        "csel x4, x4, x5, eq\n"

        "ld1 {v15.4s}, [x4]\n"  // multiplier_fixedpoint

        // Now we load: bias data, LHS sums data, RHS sums data.

        // First, load the base pointers from the params.
        "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"

        "add x5, x1, %x[row], lsl #2\n"
        "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
        "csel x1, x1, x5, eq\n"

        // Load 4 bias values.
        "ld1 {v14.4s}, [x1]\n"

        // Now that we know what LHS and RHS data the next iteration of the
        // main loop will need to load, we start loading the first 64 bytes of
        // LHS and 16 bytes of RHS, into v0 -- v4, as we don't need those
        // registers anymore in the rest of the work on the current block.
        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
        "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
        "add %[rhs_ptr], %[rhs_ptr], #48\n"

        // Add to the bias values the product
        // (depth * lhs_zero_point * rhs_zero_point); see the term NZ1Z2 in
        // equation (7) in https://arxiv.org/pdf/1712.05877.pdf.
        "add v14.4s, v14.4s, v9.4s\n"

        // Perform the bias-addition (per the above, we have just folded into
        // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
        // (all four 32-bit accumulators are in v16 at this point)
        "add v16.4s, v16.4s, v14.4s\n"

        "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
        "beq 401f\n"
        "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
        "add x3, x3, %x[col], lsl #2\n"
        "ld1 {v14.4s}, [x3]\n"
        "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
        "dup v10.4s, w5\n"  // create lhs_zero_point_vec
        // Subtract rhs_sums * lhs_zero_point, per
        // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
        "mls v16.4s, v10.4s, v14.s[0]\n"
        "401:\n"

        "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
        "beq 402f\n"
        "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
        "add x2, x2, %x[row], lsl #2\n"
        "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
        // Load 4 lhs_sums values.
        "ld1 {v11.4s}, [x2]\n"
        "ins v13.s[1], w5\n"  // rhs_zero_point
        // Compute lhs_sums * rhs_zero_point.
        "mul v11.4s, v11.4s, v13.s[1]\n"
        // Subtract lhs_sums * rhs_zero_point, per
        // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
        "sub v16.4s, v16.4s, v11.4s\n"

        // If the destination is int32, the user is asking for the raw
        // accumulators; no need for us to downquantize the values.
        "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
        "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"

        "402:\n"

        // At this point we have computed the final int32 values. Now we
        // start down-quantizing them to obtain the final 8bit values.

        // As part of this down-quantization, our int32 values will be
        // multiplied by a multiplier that has a fixed-point component and an
        // exponent component.

        // Load the exponent part of the multiplier.
        "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
        "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
        "add x5, x1, %x[row], lsl #2\n"
        "csel x1, x1, x5, eq\n"

        "ld1 {v14.4s}, [x1]\n"

        "smin v11.4s, v8.4s, v14.4s\n"
        "sub v12.4s, v14.4s, v11.4s\n"
1446
1447 // Apply the positive exponent part of the multiplier.
1448 "sshl v16.4s, v16.4s, v12.4s\n"
1449
1450 // Apply the fixed-point part of the multiplier.
1451 "sqdmulh v16.4s, v16.4s, v15.4s\n"
1452
1453 // Apply the negative exponent part of the multiplier.
1454 "srshl v16.4s, v16.4s, v11.4s\n"
1455
1456 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
1457 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
1458 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
1459 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
1460
1461 RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
1462
1463 // Cast-and-saturate from int32 to int16
1464 // After this instruction, all data is in lower half (64-bits) of v16
1465 "sqxtn v16.4h, v16.4s\n"
1466
1467 // At this point, v18 -- v31 aren't used anymore for the current block,
1468 // so we can start clearing these accumulators for the next block
1469 // (next iteration of the main loop).
1470 RUY_MAKE_ZERO(v18)
1471 RUY_MAKE_ZERO(v19)
1472
1473 // Add the destination zero point
1474 "dup v14.8h, v13.h[4]\n"
1475 "add v16.8h, v16.8h, v14.8h\n"
1476
1477 // Cast-and-saturate from int16 to uint8
1478 // Now all data is in the first 32-bits of v16
1479 "sqxtun v16.8b, v16.8h\n"
1480
1481 // Load the clamp_min, clamp_max bounds
1482 "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
1483 "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
1484 "dup v14.16b, w2\n" // clamp_min
1485 "dup v15.16b, w3\n" // clamp_max
1486
1487 // Apply the clamp_min bound
1488 "umax v16.16b, v16.16b, v14.16b\n"
1489 // Apply the clamp_max bound
1490 "umin v16.16b, v16.16b, v15.16b\n"
1491
1492 // Compute how much of the 4x1 block of destination 8bit values that
1493 // we have computed, fit in the destination matrix. Typically, all of
1494 // it fits, but when the destination matrix shape is not a multiple
1495 // of 4x1, there are some 4x1 blocks along the boundaries that do
1496 // not fit entirely.
1497 "sub w1, %w[dst_rows], %w[row]\n"
1498 "mov w3, #4\n"
1499 "cmp w1, #4\n"
1500 // Compute w1 = how many rows of the 4x1 block fit
1501 "csel w1, w1, w3, le\n"
1502
1503 // Test if w1==4, i.e. if all of the 4x1 block fits.
1504 "cmp w1, w3\n"
1505
1506 "mov x4, %[dst_ptr]\n"
1507 // Yes, all of the 4x1 block fits, go to fast path.
1508 "beq 30f\n"
1509 // Not all of the 4x1 block fits.
1510 // Store to dst_tmp_buf
1511 "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
1512 // Slow loop copying from dst_tmp_buf to dst.
1513 "mov x3, %[dst_tmp_buf]\n"
1514 "mov w6, #0\n"
1515 "50:\n"
1516 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1517 "mov w5, #0\n"
1518 "51:\n"
1519 "ldrb w7, [x3, w5, uxtw]\n"
1520 "strb w7, [x4, w5, uxtw]\n"
1521 "add w5, w5, #1\n"
1522 "cmp w5, w1\n"
1523 "blt 51b\n"
1524 "b 31f\n"
1525 "30:\n"
1526 // Yes, all of the 4x1 block fits.
1527 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1528 "mov x3, x4\n"
1529 "st1 {v16.b}[0], [x3], #1\n"
1530 "st1 {v16.b}[1], [x3], #1\n"
1531 "st1 {v16.b}[2], [x3], #1\n"
1532 "st1 {v16.b}[3], [x3], #1\n"
1533 "31:\n"
1534
1535 "add %[dst_ptr], %[dst_ptr], #4\n"
1536
1537 RUY_MAKE_ZERO(v16)
1538 RUY_MAKE_ZERO(v17)
1539
1540 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1541
1542 RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
1543
1544 // Cast-and-saturate from int32 to int16
1545 // After this, all values for output are in the lower half (64 bits) of v16.
1546 "sqxtn v16.4h, v16.4s\n"
1547
1548 // At this point, v18 -- v31 aren't used anymore for the current block,
1549 // so we can start clearing these accumulators for the next block
1550 // (next iteration of the main loop).
1551 RUY_MAKE_ZERO(v18)
1552 RUY_MAKE_ZERO(v19)
1553
1554 // Add the destination zero point
1555 "dup v14.8h, v13.h[4]\n"
1556 "add v16.8h, v16.8h, v14.8h\n"
1557
1558 // Cast-and-saturate from int16 to int8
1559 "sqxtn v16.8b, v16.8h\n"
1560 // At this point, we only need 4 lowest 8-bit values in v16.
1561
1562 // Load the clamp_min, clamp_max bounds
1563 "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
1564 "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
1565 "dup v14.16b, w2\n" // clamp_min
1566 "dup v15.16b, w3\n" // clamp_max
1567
1568 // Apply the clamp_min bound
1569 "smax v16.16b, v16.16b, v14.16b\n"
1570 // Apply the clamp_max bound
1571 "smin v16.16b, v16.16b, v15.16b\n"
1572
1573 // Compute how much of the 4x1 block of destination 8bit values
1574 // we have computed fits in the destination matrix. Typically, all of
1575 // it fits, but when the destination matrix shape is not a multiple
1576 // of 4x1, there are some 4x1 blocks along the boundaries that do
1577 // not fit entirely.
1578 "sub w1, %w[dst_rows], %w[row]\n"
1580 "mov w3, #4\n"
1581 "cmp w1, #4\n"
1582 // Compute w1 = how many rows of the 4x1 block fit
1583 "csel w1, w1, w3, le\n"
1585
1586 // Test if w1==4, i.e. if all of the 4x1 block fits.
1587 "cmp w1, w3\n"
1589 "mov x4, %[dst_ptr]\n"
1590 // Yes, all of the 4x1 block fits, go to fast path.
1591 "beq 30f\n"
1592 // Not all of the 4x1 block fits.
1593 // Store to dst_tmp_buf
1594 "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
1595 // Slow loop copying from dst_tmp_buf to dst.
1596 "mov x3, %[dst_tmp_buf]\n"
1597 "mov w6, #0\n"
1598 "50:\n"
1599 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1600 "mov w5, #0\n"
1601 "51:\n"
1602 "ldrb w7, [x3, w5, uxtw]\n"
1603 "strb w7, [x4, w5, uxtw]\n"
1604 "add w5, w5, #1\n"
1605 "cmp w5, w1\n"
1606 "blt 51b\n"
1607 "b 31f\n"
1608 "30:\n"
1609 // Yes, all of the 4x1 block fits.
1610 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1611 "mov x3, x4\n"
1612 "st1 {v16.b}[0], [x3], #1\n"
1613 "st1 {v16.b}[1], [x3], #1\n"
1614 "st1 {v16.b}[2], [x3], #1\n"
1615 "st1 {v16.b}[3], [x3], #1\n"
1616 "31:\n"
1617
1618 "add %[dst_ptr], %[dst_ptr], #4\n"
1619
1620 RUY_MAKE_ZERO(v16)
1621 RUY_MAKE_ZERO(v17)
1622
1623 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1624
1625 RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
1626
1627 // Add the destination zero point
1628 "dup v14.4h, v13.h[4]\n"
1629 "saddw v16.4s, v16.4s, v14.4h\n"
1630
1631 // Cast-and-saturate from int32 to int16
1632 // After this instruction, all data is in lower half of v16.
1633 "sqxtn v16.4h, v16.4s\n"
1634
1635 // At this point, v18 -- v31 aren't used anymore for the current block,
1636 // so we can start clearing these accumulators for the next block
1637 // (next iteration of the main loop).
1638 RUY_MAKE_ZERO(v18)
1639 RUY_MAKE_ZERO(v19)
1640
1641 // Load the clamp_min, clamp_max bounds
1642 "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
1643 "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
1644 "dup v14.8h, w2\n" // clamp_min
1645 "dup v15.8h, w3\n" // clamp_max
1646
1647 // Apply the clamp_min bound
1648 "smax v16.8h, v16.8h, v14.8h\n"
1649 // Apply the clamp_max bound
1650 "smin v16.8h, v16.8h, v15.8h\n"
1651
1652 // Compute how much of the 4x1 block of destination 16bit values
1653 // we have computed fits in the destination matrix. Typically, all of
1654 // it fits, but when the destination matrix shape is not a multiple
1655 // of 4x1, there are some 4x1 blocks along the boundaries that do
1656 // not fit entirely.
1657 "sub w1, %w[dst_rows], %w[row]\n"
1659 "mov w3, #4\n"
1660 "cmp w1, #4\n"
1661 // Compute w1 = how many rows of the 4x1 block fit
1662 "csel w1, w1, w3, le\n"
1664
1665 // Test if w1==4, i.e. if all of the 4x1 block fits.
1666 "cmp w1, w3\n"
1667 "mov x4, %[dst_ptr]\n"
1668 // Yes, all of the 4x1 block fits, go to fast path.
1669 "beq 30f\n"
1670 // Not all of the 4x1 block fits.
1671 // Store to dst_tmp_buf
1672 "str q16, [%[dst_tmp_buf], #0]\n"
1673 // Slow loop copying from dst_tmp_buf to dst.
1674 "mov x3, %[dst_tmp_buf]\n"
1675 "mov w6, #0\n"
1676 "50:\n"
1677 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1678 "mov w5, #0\n"
1679 "51:\n"
1680 "ldrh w7, [x3, x5, lsl #1]\n"
1681 "strh w7, [x4, x5, lsl #1]\n"
1682 "add w5, w5, #1\n"
1683 "cmp w5, w1\n"
1684 "blt 51b\n"
1685 "blt 50b\n"
1686 "b 31f\n"
1687 "30:\n"
1688 // Yes, all of the 4x1 block fits.
1689 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1690 "mov x3, x4\n"
1691 "st1 {v16.h}[0], [x3], #2\n"
1692 "st1 {v16.h}[1], [x3], #2\n"
1693 "st1 {v16.h}[2], [x3], #2\n"
1694 "st1 {v16.h}[3], [x3], #2\n"
1695 "31:\n"
1696
1697 "add %[dst_ptr], %[dst_ptr], #8\n"
1698
1699 RUY_MAKE_ZERO(v16)
1700 RUY_MAKE_ZERO(v17)
1701
1702 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1703
1704 RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
1705
1706 // Since the store type is the same as the accum type, no need for
1707 // downcast. There's also no need for clamp by min/max.
1708
1709 // Compute how much of the 4x1 block of destination 32bit values
1710 // we have computed fits in the destination matrix. Typically, all of
1711 // it fits, but when the destination matrix shape is not a multiple
1712 // of 4x1, there are some 4x1 blocks along the boundaries that do
1713 // not fit entirely.
1714 "sub w1, %w[dst_rows], %w[row]\n"
1716 "mov w3, #4\n"
1717 "cmp w1, #4\n"
1718 // Compute w1 = how many rows of the 4x1 block fit
1719 "csel w1, w1, w3, le\n"
1721
1722 // Test if w1==4, i.e. if all of the 4x1 block fits.
1723 "cmp w1, w3\n"
1725 "mov x4, %[dst_ptr]\n"
1726 // Yes, all of the 4x1 block fits, go to fast path.
1727 "beq 30f\n"
1728 // Not all of the 4x1 block fits.
1729 // Store to dst_tmp_buf
1730 "str q16, [%[dst_tmp_buf], #0]\n"
1731 // Slow loop copying from dst_tmp_buf to dst.
1732 "mov x3, %[dst_tmp_buf]\n"
1733 "mov w6, #0\n"
1734 "50:\n"
1735 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1736 "mov w5, #0\n"
1737 "51:\n"
1738 "ldr w7, [x3, x5, lsl #2]\n"
1739 "str w7, [x4, x5, lsl #2]\n"
1740 "add w5, w5, #1\n"
1741 "cmp w5, w1\n"
1742 "blt 51b\n"
1743 "b 31f\n"
1744 "30:\n"
1745 // Yes, all of the 4x1 block fits.
1746 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
1747 "mov x3, x4\n"
1748 "st1 {v16.s}[0], [x3], #4\n"
1749 "st1 {v16.s}[1], [x3], #4\n"
1750 "st1 {v16.s}[2], [x3], #4\n"
1751 "st1 {v16.s}[3], [x3], #4\n"
1752 "31:\n"
1753
1754 "add %[dst_ptr], %[dst_ptr], #16\n"
1755
1756 RUY_MAKE_ZERO(v16)
1757 RUY_MAKE_ZERO(v17)
1758 RUY_MAKE_ZERO(v18)
1759 RUY_MAKE_ZERO(v19)
1760
1761 RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
1762
1763 // For the next block: perform the first few multiply-adds on the data
1764 // that we have already loaded.
1765 "smull v8.8h, v0.8b, v4.8b\n"
1766 "smull v9.8h, v1.8b, v4.8b\n"
1767 "smull v10.8h, v2.8b, v4.8b\n"
1768 "smull v11.8h, v3.8b, v4.8b\n"
1769 "smlal2 v8.8h, v0.16b, v4.16b\n"
1770 "smlal2 v9.8h, v1.16b, v4.16b\n"
1771 "smlal2 v10.8h, v2.16b, v4.16b\n"
1772 "smlal2 v11.8h, v3.16b, v4.16b\n"
1773
1774 // Reload some params --- we had used x5 -- x7 for a few other things
1775 // since the last time we had loaded them.
1776 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
1777 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
1778 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
1779
1780 // Move to the next block of the destination matrix, for the next iter
1781 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
1782 // been updated earlier.
1783 // Have we reached the end row?
1784 "cmp %w[row], w7\n"
1785 "beq 20f\n" // yes, end row.
1786 // Not end row. Move to the next row.
1787 "add %w[row], %w[row], #4\n"
1788 "b 21f\n"
1789 "20:\n"
1790 // Was already at end row.
1791 "mov %w[row], w6\n" // Move back to first row.
1792 "add %w[col], %w[col], #4\n" // Move to the next column.
1793 "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
1794 "mov %[dst_ptr], %[dst_col_ptr]\n"
1795 "21:\n"
1796
1797 // Main loop exit condition: have we hit the end column?
1798 "cmp %w[col], w8\n"
1799
1800 // w1 is the number of levels of depth that we have already loaded
1801 // LHS and RHS data for. Corresponding to the initial ld1 instructions
1802 // above, this is currently 16.
1803 "mov w1, #16\n"
1804
1805 "ble 1b\n"
1806
1807 // clang-format on
1808
1809 : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
1810 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
1811 [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
1812 : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
1813 [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
1814 [dst_type_id] "r"(params.dst_type_id)
1815 : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
1816 "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
1817 "v13", "v14", "v15", "v16", "v17", "v18", "v19");
1818 }
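
// A minimal scalar sketch of the edge handling shared by the four store paths
// above, for reference only: this helper is illustrative, is not called by any
// kernel, and its name and signature are ours, not part of ruy's API. It
// mirrors the w1 computation and the "30:" fast path / "50:"-"51:" slow copy
// loops for a 4x1 uint8 block staged in dst_tmp_buf.
inline void StoreEdgeAware4x1Sketch(const std::uint8_t* block_tmp,
                                    std::uint8_t* dst, int dst_rows, int row) {
  // w1 in the asm: how many of the 4 rows of the block actually fit.
  int rows_fit = dst_rows - row;
  if (rows_fit > 4) rows_fit = 4;
  // When rows_fit == 4 the asm takes the fast path ("30:") and stores all 4
  // values directly; otherwise the slow path ("50:"/"51:") copies from
  // dst_tmp_buf only the rows that fit. Both reduce to this loop.
  for (int r = 0; r < rows_fit; ++r) {
    dst[r] = block_tmp[r];
  }
}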
1819
1820 // Variant of the above Kernel8bitNeon, tuned for A55-ish CPUs.
1821 // Specifically here, the relevant in-order CPUs are ARM Cortex-A53 and
1822 // the original Cortex-A55, since these are 64-bit and do not support dotprod.
1823 //
1824 // While this kernel does not have a direct equivalent in gemmlowp, it was
1825 // developed based on insights that David Mansell at ARM shared with their
1826 // contribution of gemmlowp kernels tuned for Cortex-A53, with very helpful
1827 // comments. Specifically, see this comment about tuning for Cortex-A53:
1828 // https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215
1829 void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params) {
1830 profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)");
1831
1832 CheckOffsetsInKernelParams8bit(params);
1833
1834 const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
1835 const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
1836 const std::int8_t* lhs_ptr = lhs_col_ptr;
1837 const std::int8_t* rhs_ptr = rhs_col_ptr;
1838 void* dst_col_ptr = params.dst_base_ptr;
1839 void* dst_ptr = dst_col_ptr;
1840 int row = params.start_row;
1841 int col = params.start_col;
1842
1843 // The asm kernel below has the following NEON register allocation:
1844 //
1845 // v16 -- v31 are int32 accumulators.
1846 // During accumulation, v0 -- v3 are used to load int8 data from LHS and
1847 // v4 -- v7 from RHS:
1848 //
1849 // int8 RHS 16x4 block
1850 // /-----------------------------------------|
1851 // |v4.b[0] ... v7.b[0] |
1852 // | ... ... |
1853 // |v4.b[15] ... v7.b[15] |
1854 // \-----------------------------------------/
1855 // int8 LHS 4x16 block
1856 // /---------------------\ /-----------------------------------------|
1857 // |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s |
1858 // |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s |
1859 // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s |
1860 // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s |
1861 // \---------------------/ \-----------------------------------------/
1862 // int32 accumulators 4x4 block
1863 asm volatile(
1864 #define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
1865
1866 // clang-format off
1867
1868 // Load some parameters into registers.
1869 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
1870 RUY_MAKE_ZERO(v16)
1871 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
1872 RUY_MAKE_ZERO(v17)
1873 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
1874 RUY_MAKE_ZERO(v18)
1875 "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
1876 RUY_MAKE_ZERO(v19)
1877 "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
1878 RUY_MAKE_ZERO(v20)
1879 "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
1880 RUY_MAKE_ZERO(v21)
1881 "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1882 RUY_MAKE_ZERO(v22)
1883 "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
1884 RUY_MAKE_ZERO(v23)
1885
1886 // Load the first 64 bytes of LHS and RHS data.
1887 "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
1888 RUY_MAKE_ZERO(v24)
1889 "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
1890 RUY_MAKE_ZERO(v25)
1891 "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
1892 RUY_MAKE_ZERO(v26)
1893 "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
1894 RUY_MAKE_ZERO(v27)
1895 "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
1896 RUY_MAKE_ZERO(v28)
1897 "ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
1898 RUY_MAKE_ZERO(v29)
1899 "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
1900 RUY_MAKE_ZERO(v30)
1901 "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
1902 RUY_MAKE_ZERO(v31)
1903
1904
1905 // w1 is the number of levels of depth that we have already loaded
1906 // LHS and RHS data for. Corresponding to the initial ld1 instructions
1907 // above, this is currently 16.
1908 "mov w1, #16\n"
1909
1910 // Perform the first few multiply-adds on the data that we have already
1911 // loaded.
1912 "smull v8.8h, v0.8b, v4.8b\n"
1913 "smull v9.8h, v1.8b, v4.8b\n"
1914 "smull v10.8h, v2.8b, v4.8b\n"
1915 "smull v11.8h, v3.8b, v4.8b\n"
1916 "smull v12.8h, v0.8b, v5.8b\n"
1917 "smull v13.8h, v1.8b, v5.8b\n"
1918 "smull v14.8h, v2.8b, v5.8b\n"
1919 "smull v15.8h, v3.8b, v5.8b\n"
1920
1921 // Multiply-accumulate second-half, again into the same
1922 // 16bit local accumulator registers. This is where we
1923 // take advantage of having int8 instead of uint8 and therefore
1924 // being able to accumulate two products into int16.
1925 "smlal2 v8.8h, v0.16b, v4.16b\n"
1926 "smlal2 v9.8h, v1.16b, v4.16b\n"
1927 "smlal2 v10.8h, v2.16b, v4.16b\n"
1928 "smlal2 v11.8h, v3.16b, v4.16b\n"
1929 "smlal2 v12.8h, v0.16b, v5.16b\n"
1930 "smlal2 v13.8h, v1.16b, v5.16b\n"
1931 "smlal2 v14.8h, v2.16b, v5.16b\n"
1932 "smlal2 v15.8h, v3.16b, v5.16b\n"
1933
1934
1935 // Main loop of the whole GEMM, over rows and columns of the
1936 // destination matrix.
1937 "1:\n"
1938
1939 // Reminder - w1 is how many levels of depth we have already loaded
1940 // data for, w12 is the total depth.
1941 "cmp w1, w12\n"
1942 "beq 79f\n"
1943
1944 "2:\n"
1945
1946 // Some multiplications and 16-bit accumulation were already done above,
1947 // so we start right away in the middle.
1948 "sadalp v16.4s, v8.8h\n"
1949 "ldr d4, [%[rhs_ptr], #0]\n"
1950 "smull v8.8h, v0.8b, v6.8b\n"
1951 "ldr x7, [%[rhs_ptr], #8]\n"
1952 "sadalp v17.4s, v9.8h\n"
1953 "ldr d5, [%[rhs_ptr], #16]\n"
1954 "smull v9.8h, v1.8b, v6.8b\n"
1955 "ldr x8, [%[rhs_ptr], #24]\n"
1956 "sadalp v18.4s, v10.8h\n"
1957 "smull v10.8h, v2.8b, v6.8b\n"
1958 "sadalp v19.4s, v11.8h\n"
1959 "add %[lhs_ptr], %[lhs_ptr], #64\n"
1960 "smull v11.8h, v3.8b, v6.8b\n"
1961 "add %[rhs_ptr], %[rhs_ptr], #64\n"
1962 "sadalp v20.4s, v12.8h\n"
1963 // Each iteration of this loop advances by 16 levels of depth.
1964 "add w1, w1, #16\n"
1965 "smull v12.8h, v0.8b, v7.8b\n"
1966 // Loop termination condition
1967 "cmp w1, w12\n"
1968 "sadalp v21.4s, v13.8h\n"
1969 "ldr x3, [%[lhs_ptr], #-56]\n"
1970 "smull v13.8h, v1.8b, v7.8b\n"
1971 "ldr x4, [%[lhs_ptr], #-40]\n"
1972 "sadalp v22.4s, v14.8h\n"
1973 "ldr x5, [%[lhs_ptr], #-24]\n"
1974 "smull v14.8h, v2.8b, v7.8b\n"
1975 "ldr x6, [%[lhs_ptr], #-8]\n"
1976 "sadalp v23.4s, v15.8h\n"
1977 "smull v15.8h, v3.8b, v7.8b\n"
1978
1979 // Multiply-accumulate second-half, again into the same
1980 // 16bit local accumulator registers. This is where we
1981 // take advantage of having int8 instead of uint8 and therefore
1982 // being able to accumulate two products into int16.
1983 "smlal2 v8.8h, v0.16b, v6.16b\n"
1984 "smlal2 v9.8h, v1.16b, v6.16b\n"
1985 "smlal2 v10.8h, v2.16b, v6.16b\n"
1986 "ldr x9, [%[rhs_ptr], #-24]\n"
1987 "smlal2 v11.8h, v3.16b, v6.16b\n"
1988 "ldr d6, [%[rhs_ptr], #-32]\n"
1989 "smlal2 v12.8h, v0.16b, v7.16b\n"
1990 "ldr d0, [%[lhs_ptr], #-64]\n"
1991 "smlal2 v13.8h, v1.16b, v7.16b\n"
1992 "ldr d1, [%[lhs_ptr], #-48]\n"
1993 "smlal2 v14.8h, v2.16b, v7.16b\n"
1994 "ins v4.d[1], x7\n"
1995 "smlal2 v15.8h, v3.16b, v7.16b\n"
1996 "ins v5.d[1], x8\n"
1997
1998 "ldr d2, [%[lhs_ptr], #-32]\n"
1999 "ins v0.d[1], x3\n"
2000 "sadalp v24.4s, v8.8h\n"
2001 "ldr d3, [%[lhs_ptr], #-16]\n"
2002 "ins v1.d[1], x4\n"
2003 "smull v8.8h, v0.8b, v4.8b\n"
2004 "ins v2.d[1], x5\n"
2005 "sadalp v25.4s, v9.8h\n"
2006 "ins v3.d[1], x6\n"
2007 "smull v9.8h, v1.8b, v4.8b\n"
2008 "ldr d7, [%[rhs_ptr], #-16]\n"
2009 "sadalp v26.4s, v10.8h\n"
2010 "ldr x10, [%[rhs_ptr], #-8]\n"
2011 "smull v10.8h, v2.8b, v4.8b\n"
2012 "sadalp v27.4s, v11.8h\n"
2013 "smull v11.8h, v3.8b, v4.8b\n"
2014 "sadalp v28.4s, v12.8h\n"
2015 "smull v12.8h, v0.8b, v5.8b\n"
2016 "sadalp v29.4s, v13.8h\n"
2017 "smull v13.8h, v1.8b, v5.8b\n"
2018 "sadalp v30.4s, v14.8h\n"
2019 "smull v14.8h, v2.8b, v5.8b\n"
2020 "sadalp v31.4s, v15.8h\n"
2021 "smull v15.8h, v3.8b, v5.8b\n"
2022
2023 // Multiply-accumulate second-half, again into the same
2024 // 16bit local accumulator registers. This is where we
2025 // take advantage of having int8 instead of uint8 and therefore
2026 // being able to accumulate two products into int16.
2027 "smlal2 v8.8h, v0.16b, v4.16b\n"
2028 "smlal2 v9.8h, v1.16b, v4.16b\n"
2029 "smlal2 v10.8h, v2.16b, v4.16b\n"
2030 "smlal2 v11.8h, v3.16b, v4.16b\n"
2031
2032 "smlal2 v12.8h, v0.16b, v5.16b\n"
2033 "smlal2 v13.8h, v1.16b, v5.16b\n"
2034 "ins v6.d[1], x9\n"
2035 "smlal2 v14.8h, v2.16b, v5.16b\n"
2036 "ins v7.d[1], x10\n"
2037 "smlal2 v15.8h, v3.16b, v5.16b\n"
2038
2039 "blt 2b\n"
2040
2041 "79:\n"
2042
2043 "sadalp v16.4s, v8.8h\n"
2044 "smull v8.8h, v0.8b, v6.8b\n"
2045 "sadalp v17.4s, v9.8h\n"
2046 "smull v9.8h, v1.8b, v6.8b\n"
2047 "sadalp v18.4s, v10.8h\n"
2048 "smull v10.8h, v2.8b, v6.8b\n"
2049 "sadalp v19.4s, v11.8h\n"
2050 "smull v11.8h, v3.8b, v6.8b\n"
2051 "sadalp v20.4s, v12.8h\n"
2052 "smull v12.8h, v0.8b, v7.8b\n"
2053 "sadalp v21.4s, v13.8h\n"
2054 "smull v13.8h, v1.8b, v7.8b\n"
2055 "sadalp v22.4s, v14.8h\n"
2056 "smull v14.8h, v2.8b, v7.8b\n"
2057 "sadalp v23.4s, v15.8h\n"
2058 "smull v15.8h, v3.8b, v7.8b\n"
2059
2060 // Multiply-accumulate second-half, again into the same
2061 // 16bit local accumulator registers. This is where we
2062 // take advantage of having int8 instead of uint8 and therefore
2063 // being able to accumulate two products into int16.
2064 "smlal2 v8.8h, v0.16b, v6.16b\n"
2065 "smlal2 v9.8h, v1.16b, v6.16b\n"
2066 "smlal2 v10.8h, v2.16b, v6.16b\n"
2067 "smlal2 v11.8h, v3.16b, v6.16b\n"
2068
2069 "smlal2 v12.8h, v0.16b, v7.16b\n"
2070 "smlal2 v13.8h, v1.16b, v7.16b\n"
2071 "smlal2 v14.8h, v2.16b, v7.16b\n"
2072 "smlal2 v15.8h, v3.16b, v7.16b\n"
2073
2074 "sadalp v24.4s, v8.8h\n"
2075 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
2076 "sadalp v25.4s, v9.8h\n"
2077 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
2078 "sadalp v26.4s, v10.8h\n"
2079 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
2080 "sadalp v27.4s, v11.8h\n"
2081 "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
2082 "sadalp v28.4s, v12.8h\n"
2083 "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
2084 "sadalp v29.4s, v13.8h\n"
2085 "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
2086 "sadalp v30.4s, v14.8h\n"
2087 "sadalp v31.4s, v15.8h\n"
2088
2089 // End of accumulation. The registers v16 -- v31 contain the final
2090 // int32 accumulator values of the current 4x4 destination block.
2091 // We now have to compute the final 8-bit values from these int32
2092 // accumulators, and advance to the next 4x4 block. We intertwine
2093 // these two aspects whenever possible for optimal pipelining, both
2094 // at the data flow level (prefetch data for next block as early as
2095 // possible) and instruction pipelining level (some of the next-block
2096 // work can dual-issue with some of the final work on the current
2097 // block).
2098
2099 // Reduce 32bit accumulators horizontally.
2100 "addp v16.4s, v16.4s, v17.4s\n"
2101 "addp v18.4s, v18.4s, v19.4s\n"
2102 "addp v20.4s, v20.4s, v21.4s\n"
2103 "addp v22.4s, v22.4s, v23.4s\n"
2104 "addp v24.4s, v24.4s, v25.4s\n"
2105 "addp v26.4s, v26.4s, v27.4s\n"
2106 "addp v28.4s, v28.4s, v29.4s\n"
2107 "addp v30.4s, v30.4s, v31.4s\n"
2108
2109 // Reduce 32bit accumulators horizontally, second pass
2110 // (each pass adds pairwise; we need to add 4-wise).
2111 "addp v16.4s, v16.4s, v18.4s\n"
2112 "addp v17.4s, v20.4s, v22.4s\n"
2113 "addp v18.4s, v24.4s, v26.4s\n"
2114 "addp v19.4s, v28.4s, v30.4s\n"
2115
2116 // Logic to advance to the next block in preparation for the next
2117 // iteration of the main loop. For now, we only want to compute
2118 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
2119 // not yet ready to update the values of row and col, as we still need
2120 // the current values for the rest of the work on the current block.
2121
2122 "cmp %w[row], w7\n" // Have we finished the last row?
2123 "bge 4f\n" // If finished last row, go to 4
2124 // Not finished last row: then advance to next row.
2125 "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n"
2126 "b 5f\n"
2127 "4:\n" // Finished last row...
2128 "mov %[lhs_col_ptr], x5\n" // Go back to first row
2129 // Now we need to advance to the next column. If we already
2130 // finished the last column, then in principle we are done, however
2131 // we can't just return here, as we need to allow the end work of the
2132 // current block to complete. The good news is that at this point it
2133 // doesn't matter what data we load for the next column, since
2134 // we will exit from the main loop below before actually storing
2135 // anything computed from that data.
2136 "cmp %w[col], w8\n" // Have we finished the last column?
2137 "bge 5f\n" // If yes, just carry on without updating the column pointer.
2138 // Not finished last column: then advance to next column.
2139 "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n"
2140 "5:\n"
2141
2142 // Set the LHS and RHS data pointers to the start of the columns just
2143 // computed.
2144 "mov %[lhs_ptr], %[lhs_col_ptr]\n"
2145 "mov %[rhs_ptr], %[rhs_col_ptr]\n"
2146
2147 // Load some parameters needed for the end work on current block.
2148 "mvni v8.4s, #0\n"
2149 "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
2150 "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
2151 "ins v13.h[4], w4\n" // dst_zero_point
2152 "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
2153 "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
2154 "dup v9.4s, w3\n" // create prod_zp_depth_vec
2155
2156 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
2157
2158 // Determine the channel index.
2159 "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
2160 "csel w3, %w[row], %w[col], eq\n"
2161
2162 // Offset the bias pointer as needed given the current row, col.
2163 "add x5, x1, x3, lsl #2\n"
2164
2165 // If there is no bias, use no offset, just address the passed zero
2166 // data.
2167 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
2168 "csel x1, x1, x5, eq\n"
2169
2170 // Load 4 bias values.
2171 "ld1 {v14.4s}, [x1]\n"
2172
2173 // Load the multiplier_fixedpoint values.
2174 "add x5, x4, x3, lsl #2\n"
2175 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
2176 "csel x4, x4, x5, eq\n"
2177 "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
2178
2179 // Now that we know what LHS and RHS data the next iteration of the
2180 // main loop will need to load, we start loading the first 32 bytes of
2181 // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
2182 // in the rest of the work on the current block.
2183
2184 // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point).
2185 // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
2186 "add v14.4s, v14.4s, v9.4s\n"
2187 "ldr d0, [%[lhs_ptr], #0]\n"
2188
2189 // Perform the bias-addition (per the above, we have just folded into
2190 // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
2191 // Jump based on channel dimension.
2192 "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
2193 "bne 6f\n"
2194 // Case where channels are rows
2195
2196 "add v16.4s, v16.4s, v14.4s\n"
2197 "ldr d1, [%[lhs_ptr], #16]\n"
2198 "add v17.4s, v17.4s, v14.4s\n"
2199 "ldr d2, [%[lhs_ptr], #32]\n"
2200 "add v18.4s, v18.4s, v14.4s\n"
2201 "ldr d3, [%[lhs_ptr], #48]\n"
2202 "add v19.4s, v19.4s, v14.4s\n"
2203 "ldr d4, [%[rhs_ptr], #0]\n"
2204 "ldr d5, [%[rhs_ptr], #16]\n"
2205 "ldr d6, [%[rhs_ptr], #32]\n"
2206 "ldr d7, [%[rhs_ptr], #48]\n"
2207
2208 "b 7f\n"
2209
2210 "6:\n"
2211 // Case where channels are columns
2212 "dup v20.4s, v14.s[0]\n"
2213 "ldr d1, [%[lhs_ptr], #16]\n"
2214 "dup v21.4s, v14.s[1]\n"
2215 "ldr d2, [%[lhs_ptr], #32]\n"
2216 "dup v22.4s, v14.s[2]\n"
2217 "ldr d3, [%[lhs_ptr], #48]\n"
2218 "dup v23.4s, v14.s[3]\n"
2219 "ldr d4, [%[rhs_ptr], #0]\n"
2220 "add v16.4s, v16.4s, v20.4s\n"
2221 "ldr d5, [%[rhs_ptr], #16]\n"
2222 "add v17.4s, v17.4s, v21.4s\n"
2223 "ldr d6, [%[rhs_ptr], #32]\n"
2224 "add v18.4s, v18.4s, v22.4s\n"
2225 "ldr d7, [%[rhs_ptr], #48]\n"
2226 "add v19.4s, v19.4s, v23.4s\n"
2227 "7:\n"
2228
2229 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
2230 "beq 401f\n"
2231 "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
2232 "add x3, x3, %x[col], lsl #2\n"
2233 "ld1 {v14.4s}, [x3]\n"
2234 "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
2235 "dup v10.4s, w5\n" // create lhs_zero_point_vec
2236 // Subtract rhs_sums * lhs_zero_point, per
2237 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
2238 "mls v16.4s, v10.4s, v14.s[0]\n"
2239 "mls v17.4s, v10.4s, v14.s[1]\n"
2240 "mls v18.4s, v10.4s, v14.s[2]\n"
2241 "mls v19.4s, v10.4s, v14.s[3]\n"
2242 "401:\n"
2243
2244 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
2245 "beq 402f\n"
2246 "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
2247 "add x2, x2, %x[row], lsl #2\n"
2248 "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
2249 // Load 4 lhs_sums values.
2250 "ld1 {v11.4s}, [x2]\n"
2251 "ins v13.s[1], w5\n" // rhs_zero_point
2252 // Compute lhs_sums * rhs_zero_point.
2253 "mul v11.4s, v11.4s, v13.s[1]\n"
2254 // Subtract lhs_sums * rhs_zero_point, per
2255 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
2256 "sub v16.4s, v16.4s, v11.4s\n"
2257 "sub v17.4s, v17.4s, v11.4s\n"
2258 "sub v18.4s, v18.4s, v11.4s\n"
2259 "sub v19.4s, v19.4s, v11.4s\n"
2260
2261 // If the destination is int32, it means the user asks for the raw
2262 // accumulators, and there is no need for us to downquantize the values.
2263 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
2264 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
2265
2266 "402:\n"
2267
2268 // At this point we have computed the final int32 values. Now we
2269 // start down-quantizing them to obtain the final 8bit values from them.
2270
2271 // As part of this down-quantization, our int32 values will be
2272 // multiplied by a multiplier that has a fixed-point component and an
2273 // exponent component.
2274
2275 // Determine the channel index.
2276 "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
2277 "csel w3, %w[row], %w[col], eq\n"
2278
2279 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
2280 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
2281 "add x5, x1, x3, lsl #2\n"
2282 "csel x1, x1, x5, eq\n"
2283
2284 "ld1 {v14.4s}, [x1]\n"
2285
2286 "smin v11.4s, v8.4s, v14.4s\n"
2287 "ldr x1, [%[lhs_ptr], #8]\n"
2288 "sub v12.4s, v14.4s, v11.4s\n"
2289
2290 // Jump based on channel dimension.
2291 "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
2292 "bne 8f\n"
2293 // Case where channels are rows
2294
2295
2296 // Apply the positive exponent part of the multiplier.
2297 "sshl v16.4s, v16.4s, v12.4s\n"
2298 "ldr x2, [%[lhs_ptr], #24]\n"
2299 "sshl v17.4s, v17.4s, v12.4s\n"
2300 "ldr x3, [%[lhs_ptr], #40]\n"
2301 "sshl v18.4s, v18.4s, v12.4s\n"
2302 "ldr x4, [%[lhs_ptr], #56]\n"
2303 "sshl v19.4s, v19.4s, v12.4s\n"
2304
2305
2306 // Apply the fixed-point part of the multiplier.
2307 "ins v0.d[1], x1\n"
2308 "ldr x1, [%[rhs_ptr], #8]\n"
2309 "sqdmulh v16.4s, v16.4s, v15.4s\n"
2310 "ins v1.d[1], x2\n"
2311 "ldr x2, [%[rhs_ptr], #24]\n"
2312 "sqdmulh v17.4s, v17.4s, v15.4s\n"
2313 "ins v2.d[1], x3\n"
2314 "ldr x3, [%[rhs_ptr], #40]\n"
2315 "sqdmulh v18.4s, v18.4s, v15.4s\n"
2316 "ins v3.d[1], x4\n"
2317 "ldr x4, [%[rhs_ptr], #56]\n"
2318 "sqdmulh v19.4s, v19.4s, v15.4s\n"
2319
2320 // Apply the negative exponent part of the multiplier.
2321 "srshl v16.4s, v16.4s, v11.4s\n"
2322 "srshl v17.4s, v17.4s, v11.4s\n"
2323 "srshl v18.4s, v18.4s, v11.4s\n"
2324 "srshl v19.4s, v19.4s, v11.4s\n"
2325
2326 "b 9f\n"
2327
2328 "8:\n"
2329 // Case where channels are columns
2330
2331 // Apply the positive exponent part of the multiplier.
2332 "dup v20.4s, v12.s[0]\n"
2333 "ldr x2, [%[lhs_ptr], #24]\n"
2334 "ldr x3, [%[lhs_ptr], #40]\n"
2335 "dup v21.4s, v12.s[1]\n"
2336 "ldr x4, [%[lhs_ptr], #56]\n"
2337 "dup v22.4s, v12.s[2]\n"
2338 "ins v0.d[1], x1\n"
2339 "dup v23.4s, v12.s[3]\n"
2340 "ldr x1, [%[rhs_ptr], #8]\n"
2341 "sshl v16.4s, v16.4s, v20.4s\n"
2342 "ins v1.d[1], x2\n"
2343 "sshl v17.4s, v17.4s, v21.4s\n"
2344 "ldr x2, [%[rhs_ptr], #24]\n"
2345 "sshl v18.4s, v18.4s, v22.4s\n"
2346 "ins v2.d[1], x3\n"
2347 "sshl v19.4s, v19.4s, v23.4s\n"
2348 "ldr x3, [%[rhs_ptr], #40]\n"
2349
2350 // Apply the fixed-point part of the multiplier.
2351 "sqdmulh v16.4s, v16.4s, v15.s[0]\n"
2352 "ins v3.d[1], x4\n"
2353 "sqdmulh v17.4s, v17.4s, v15.s[1]\n"
2354 "ldr x4, [%[rhs_ptr], #56]\n"
2355 "sqdmulh v18.4s, v18.4s, v15.s[2]\n"
2356 "dup v20.4s, v11.s[0]\n"
2357 "sqdmulh v19.4s, v19.4s, v15.s[3]\n"
2358
2359 // Apply the negative exponent part of the multiplier.
2360 "dup v21.4s, v11.s[1]\n"
2361 "srshl v16.4s, v16.4s, v20.4s\n"
2362 "dup v22.4s, v11.s[2]\n"
2363 "srshl v17.4s, v17.4s, v21.4s\n"
2364 "dup v23.4s, v11.s[3]\n"
2365 "srshl v18.4s, v18.4s, v22.4s\n"
2366 "srshl v19.4s, v19.4s, v23.4s\n"
2367
2368 "9:\n"
2369
2370 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
2371 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
2372 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
2373 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
2374
2375 RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
2376
2377 "ins v4.d[1], x1\n"
2378 "sqxtn v16.4h, v16.4s\n"
2379 "ins v5.d[1], x2\n"
2380 "sqxtn2 v16.8h, v17.4s\n"
2381 "ins v6.d[1], x3\n"
2382 "sqxtn v17.4h, v18.4s\n"
2383 "ins v7.d[1], x4\n"
2384 RUY_MAKE_ZERO(v18)
2385 "sqxtn2 v17.8h, v19.4s\n"
2386
2387 // At this point, v18 -- v31 aren't used anymore for the current block,
2388 // so we can start clearing these accumulators for the next block
2389 // (next iteration of the main loop).
2390 RUY_MAKE_ZERO(v19)
2391
2392 // Add the destination zero point
2393 "add %[lhs_ptr], %[lhs_ptr], #64\n"
2394 "dup v14.8h, v13.h[4]\n"
2395 RUY_MAKE_ZERO(v20)
2396 "add %[rhs_ptr], %[rhs_ptr], #64\n"
2397 "add v16.8h, v16.8h, v14.8h\n"
2398 RUY_MAKE_ZERO(v21)
2399 "add v17.8h, v17.8h, v14.8h\n"
2400 RUY_MAKE_ZERO(v22)
2401
2402 // Cast-and-saturate from int16 to uint8
2403 "sqxtun v16.8b, v16.8h\n"
2404 RUY_MAKE_ZERO(v23)
2405 "sqxtun2 v16.16b, v17.8h\n"
2406 RUY_MAKE_ZERO(v24)
2407
2408 // Load the clamp_min, clamp_max bounds
2409 "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
2410 RUY_MAKE_ZERO(v25)
2411 "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
2412 RUY_MAKE_ZERO(v26)
2413 "dup v14.16b, w2\n" // clamp_min
2414 RUY_MAKE_ZERO(v27)
2415 "dup v15.16b, w3\n" // clamp_max
2416 RUY_MAKE_ZERO(v28)
2417
2418 // Apply the clamp_min bound
2419 "umax v16.16b, v16.16b, v14.16b\n"
2420 RUY_MAKE_ZERO(v29)
2421 // Apply the clamp_max bound
2422 "umin v16.16b, v16.16b, v15.16b\n"
2423 RUY_MAKE_ZERO(v30)
2424
2425 // Compute how much of the 4x4 block of destination 8bit values
2426 // we have computed fits in the destination matrix. Typically, all of
2427 // it fits, but when the destination matrix shape is not a multiple
2428 // of 4x4, there are some 4x4 blocks along the boundaries that do
2429 // not fit entirely.
2430 "sub w1, %w[dst_rows], %w[row]\n"
2431 RUY_MAKE_ZERO(v31)
2432 "sub w2, %w[dst_cols], %w[col]\n"
2433 "mov w3, #4\n"
2434 "cmp w1, #4\n"
2435 // Compute w1 = how many rows of the 4x4 block fit
2436 "csel w1, w1, w3, le\n"
2437 "cmp w2, #4\n"
2438 // Compute w2 = how many cols of the 4x4 block fit
2439 "csel w2, w2, w3, le\n"
2440
2441 // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
2442 "cmp w1, w3\n"
2443 "ccmp w2, w3, 0, eq\n"
2444 "mov x4, %[dst_ptr]\n"
2445 // Yes, all of the 4x4 block fits, go to fast path.
2446 "beq 30f\n"
2447 // Not all of the 4x4 block fits.
2448 // Store to dst_tmp_buf
2449 "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
2450 // Slow loop copying from dst_tmp_buf to dst.
2451 "mov x3, %[dst_tmp_buf]\n"
2452 "mov w6, #0\n"
2453 "50:\n"
2454 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2455 "mov w5, #0\n"
2456 "51:\n"
2457 "ldrb w7, [x3, w5, uxtw]\n"
2458 "strb w7, [x4, w5, uxtw]\n"
2459 "add w5, w5, #1\n"
2460 "cmp w5, w1\n"
2461 "blt 51b\n"
2462 "add w6, w6, #1\n"
2463 "add x3, x3, #4\n"
2464 "add x4, x4, x11\n"
2465 "cmp w6, w2\n"
2466 "blt 50b\n"
2467 "b 31f\n"
2468 "30:\n"
2469 // Yes, all of the 4x4 block fits.
2470 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2471 "mov x3, x4\n"
2472 "st1 {v16.b}[0], [x3], #1\n"
2473 "add x4, x4, x11\n"
2474 "st1 {v16.b}[1], [x3], #1\n"
2475 "st1 {v16.b}[2], [x3], #1\n"
2476 "st1 {v16.b}[3], [x3], #1\n"
2477 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2478 "mov x3, x4\n"
2479 "st1 {v16.b}[4], [x3], #1\n"
2480 "add x4, x4, x11\n"
2481 "st1 {v16.b}[5], [x3], #1\n"
2482 "st1 {v16.b}[6], [x3], #1\n"
2483 "st1 {v16.b}[7], [x3], #1\n"
2484 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2485 "mov x3, x4\n"
2486 "st1 {v16.b}[8], [x3], #1\n"
2487 "add x4, x4, x11\n"
2488 "st1 {v16.b}[9], [x3], #1\n"
2489 "st1 {v16.b}[10], [x3], #1\n"
2490 "st1 {v16.b}[11], [x3], #1\n"
2491 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2492 "mov x3, x4\n"
2493 "st1 {v16.b}[12], [x3], #1\n"
2494 "add x4, x4, x11\n"
2495 "st1 {v16.b}[13], [x3], #1\n"
2496 "st1 {v16.b}[14], [x3], #1\n"
2497 "st1 {v16.b}[15], [x3], #1\n"
2498 "31:\n"
2499
2500 "add %[dst_ptr], %[dst_ptr], #4\n"
2501
2502 RUY_MAKE_ZERO(v16)
2503 RUY_MAKE_ZERO(v17)
2504
2505 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2506
2507 RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
2508
2509 "ins v4.d[1], x1\n"
2510 "sqxtn v16.4h, v16.4s\n"
2511 "ins v5.d[1], x2\n"
2512 "sqxtn2 v16.8h, v17.4s\n"
2513 "ins v6.d[1], x3\n"
2514 "sqxtn v17.4h, v18.4s\n"
2515 "ins v7.d[1], x4\n"
2516 RUY_MAKE_ZERO(v18)
2517 "sqxtn2 v17.8h, v19.4s\n"
2518
2519 // At this point, v18 -- v31 aren't used anymore for the current block,
2520 // so we can start clearing these accumulators for the next block
2521 // (next iteration of the main loop).
2522 RUY_MAKE_ZERO(v19)
2523
2524 // Add the destination zero point
2525 "add %[lhs_ptr], %[lhs_ptr], #64\n"
2526 "dup v14.8h, v13.h[4]\n"
2527 RUY_MAKE_ZERO(v20)
2528 "add %[rhs_ptr], %[rhs_ptr], #64\n"
2529 "add v16.8h, v16.8h, v14.8h\n"
2530 RUY_MAKE_ZERO(v21)
2531 "add v17.8h, v17.8h, v14.8h\n"
2532 RUY_MAKE_ZERO(v22)
2533
2534 // Cast-and-saturate from int16 to int8
2535 "sqxtn v16.8b, v16.8h\n"
2536 RUY_MAKE_ZERO(v23)
2537 "sqxtn2 v16.16b, v17.8h\n"
2538 RUY_MAKE_ZERO(v24)
2539
2540 // Load the clamp_min, clamp_max bounds
2541 "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
2542 RUY_MAKE_ZERO(v25)
2543 "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
2544 RUY_MAKE_ZERO(v26)
2545 "dup v14.16b, w2\n" // clamp_min
2546 RUY_MAKE_ZERO(v27)
2547 "dup v15.16b, w3\n" // clamp_max
2548 RUY_MAKE_ZERO(v28)
2549
2550 // Apply the clamp_min bound
2551 "smax v16.16b, v16.16b, v14.16b\n"
2552 RUY_MAKE_ZERO(v29)
2553 // Apply the clamp_max bound
2554 "smin v16.16b, v16.16b, v15.16b\n"
2555 RUY_MAKE_ZERO(v30)
2556
2557 // Compute how much of the 4x4 block of destination 8bit values
2558 // we have computed fits in the destination matrix. Typically, all of
2559 // it fits, but when the destination matrix shape is not a multiple
2560 // of 4x4, there are some 4x4 blocks along the boundaries that do
2561 // not fit entirely.
2562 "sub w1, %w[dst_rows], %w[row]\n"
2563 RUY_MAKE_ZERO(v31)
2564 "sub w2, %w[dst_cols], %w[col]\n"
2565 "mov w3, #4\n"
2566 "cmp w1, #4\n"
2567 // Compute w1 = how many rows of the 4x4 block fit
2568 "csel w1, w1, w3, le\n"
2569 "cmp w2, #4\n"
2570 // Compute w2 = how many cols of the 4x4 block fit
2571 "csel w2, w2, w3, le\n"
2572
2573 // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
2574 "cmp w1, w3\n"
2575 "ccmp w2, w3, 0, eq\n"
2576 "mov x4, %[dst_ptr]\n"
2577 // Yes, all of the 4x4 block fits, go to fast path.
2578 "beq 30f\n"
2579 // Not all of the 4x4 block fits.
2580 // Store to dst_tmp_buf
2581 "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
2582 // Slow loop copying from dst_tmp_buf to dst.
2583 "mov x3, %[dst_tmp_buf]\n"
2584 "mov w6, #0\n"
2585 "50:\n"
2586 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2587 "mov w5, #0\n"
2588 "51:\n"
2589 "ldrb w7, [x3, w5, uxtw]\n"
2590 "strb w7, [x4, w5, uxtw]\n"
2591 "add w5, w5, #1\n"
2592 "cmp w5, w1\n"
2593 "blt 51b\n"
2594 "add w6, w6, #1\n"
2595 "add x3, x3, #4\n"
2596 "add x4, x4, x11\n"
2597 "cmp w6, w2\n"
2598 "blt 50b\n"
2599 "b 31f\n"
2600 "30:\n"
2601 // Yes, all of the 4x4 block fits.
2602 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2603 "mov x3, x4\n"
2604 "st1 {v16.b}[0], [x3], #1\n"
2605 "add x4, x4, x11\n"
2606 "st1 {v16.b}[1], [x3], #1\n"
2607 "st1 {v16.b}[2], [x3], #1\n"
2608 "st1 {v16.b}[3], [x3], #1\n"
2609 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2610 "mov x3, x4\n"
2611 "st1 {v16.b}[4], [x3], #1\n"
2612 "add x4, x4, x11\n"
2613 "st1 {v16.b}[5], [x3], #1\n"
2614 "st1 {v16.b}[6], [x3], #1\n"
2615 "st1 {v16.b}[7], [x3], #1\n"
2616 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2617 "mov x3, x4\n"
2618 "st1 {v16.b}[8], [x3], #1\n"
2619 "add x4, x4, x11\n"
2620 "st1 {v16.b}[9], [x3], #1\n"
2621 "st1 {v16.b}[10], [x3], #1\n"
2622 "st1 {v16.b}[11], [x3], #1\n"
2623 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2624 "mov x3, x4\n"
2625 "st1 {v16.b}[12], [x3], #1\n"
2626 "add x4, x4, x11\n"
2627 "st1 {v16.b}[13], [x3], #1\n"
2628 "st1 {v16.b}[14], [x3], #1\n"
2629 "st1 {v16.b}[15], [x3], #1\n"
2630 "31:\n"
2631
2632 "add %[dst_ptr], %[dst_ptr], #4\n"
2633
2634 RUY_MAKE_ZERO(v16)
2635 RUY_MAKE_ZERO(v17)
2636
2637 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2638
2639 RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
2640
2641 // Add the destination zero point
2642 "dup v14.4h, v13.h[4]\n"
2643 "saddw v16.4s, v16.4s, v14.4h\n"
2644 "saddw v17.4s, v17.4s, v14.4h\n"
2645 "saddw v18.4s, v18.4s, v14.4h\n"
2646 "saddw v19.4s, v19.4s, v14.4h\n"
2647
2648 // Cast-and-saturate from int32 to int16
2649 "ins v4.d[1], x1\n"
2650 "sqxtn v16.4h, v16.4s\n"
2651 "ins v5.d[1], x2\n"
2652 "sqxtn2 v16.8h, v17.4s\n"
2653 "ins v6.d[1], x3\n"
2654 "sqxtn v17.4h, v18.4s\n"
2655 "ins v7.d[1], x4\n"
2656 RUY_MAKE_ZERO(v18)
2657 "sqxtn2 v17.8h, v19.4s\n"
2658
2659 // At this point, v18 -- v31 aren't used anymore for the current block,
2660 // so we can start clearing these accumulators for the next block
2661 // (next iteration of the main loop).
2662 RUY_MAKE_ZERO(v19)
2663
2664 "add %[lhs_ptr], %[lhs_ptr], #64\n"
2665 RUY_MAKE_ZERO(v20)
2666 "add %[rhs_ptr], %[rhs_ptr], #64\n"
2667 RUY_MAKE_ZERO(v21)
2668 RUY_MAKE_ZERO(v22)
2669
2670 RUY_MAKE_ZERO(v23)
2671 RUY_MAKE_ZERO(v24)
2672
2673 // Load the clamp_min, clamp_max bounds
2674 "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
2675 RUY_MAKE_ZERO(v25)
2676 "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
2677 RUY_MAKE_ZERO(v26)
2678 "dup v14.8h, w2\n" // clamp_min
2679 RUY_MAKE_ZERO(v27)
2680 "dup v15.8h, w3\n" // clamp_max
2681 RUY_MAKE_ZERO(v28)
2682
2683 // Apply the clamp_min bound
2684 "smax v16.8h, v16.8h, v14.8h\n"
2685 "smax v17.8h, v17.8h, v14.8h\n"
2686 RUY_MAKE_ZERO(v29)
2687 // Apply the clamp_max bound
2688 "smin v16.8h, v16.8h, v15.8h\n"
2689 "smin v17.8h, v17.8h, v15.8h\n"
2690 RUY_MAKE_ZERO(v30)
2691
2692 // Compute how much of the 4x4 block of destination 16bit values
2693 // we have computed fits in the destination matrix. Typically, all of
2694 // it fits, but when the destination matrix shape is not a multiple
2695 // of 4x4, there are some 4x4 blocks along the boundaries that do
2696 // not fit entirely.
2697 "sub w1, %w[dst_rows], %w[row]\n"
2698 RUY_MAKE_ZERO(v31)
2699 "sub w2, %w[dst_cols], %w[col]\n"
2700 "mov w3, #4\n"
2701 "cmp w1, #4\n"
2702 // Compute w1 = how many rows of the 4x4 block fit
2703 "csel w1, w1, w3, le\n"
2704 "cmp w2, #4\n"
2705 // Compute w2 = how many cols of the 4x4 block fit
2706 "csel w2, w2, w3, le\n"
2707
2708 // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
2709 "cmp w1, w3\n"
2710 "ccmp w2, w3, 0, eq\n"
2711 "mov x4, %[dst_ptr]\n"
2712 // Yes, all of the 4x4 block fits, go to fast path.
2713 "beq 30f\n"
2714 // Not all of the 4x4 block fits.
2715 // Store to dst_tmp_buf
2716 "str q16, [%[dst_tmp_buf], #0]\n"
2717 "str q17, [%[dst_tmp_buf], #16]\n"
2718 // Slow loop copying from dst_tmp_buf to dst.
2719 "mov x3, %[dst_tmp_buf]\n"
2720 "mov w6, #0\n"
2721 "50:\n"
2722 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2723 "mov w5, #0\n"
2724 "51:\n"
2725 "ldrh w7, [x3, x5, lsl #1]\n"
2726 "strh w7, [x4, x5, lsl #1]\n"
2727 "add w5, w5, #1\n"
2728 "cmp w5, w1\n"
2729 "blt 51b\n"
2730 "add w6, w6, #1\n"
2731 "add x3, x3, #8\n"
2732 "add x4, x4, x11\n"
2733 "cmp w6, w2\n"
2734 "blt 50b\n"
2735 "b 31f\n"
2736 "30:\n"
2737 // Yes, all of the 4x4 block fits.
2738 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2739 "mov x3, x4\n"
2740 "st1 {v16.h}[0], [x3], #2\n"
2741 "add x4, x4, x11\n"
2742 "st1 {v16.h}[1], [x3], #2\n"
2743 "st1 {v16.h}[2], [x3], #2\n"
2744 "st1 {v16.h}[3], [x3], #2\n"
2745 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2746 "mov x3, x4\n"
2747 "st1 {v16.h}[4], [x3], #2\n"
2748 "add x4, x4, x11\n"
2749 "st1 {v16.h}[5], [x3], #2\n"
2750 "st1 {v16.h}[6], [x3], #2\n"
2751 "st1 {v16.h}[7], [x3], #2\n"
2752 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2753 "mov x3, x4\n"
2754 "st1 {v17.h}[0], [x3], #2\n"
2755 "add x4, x4, x11\n"
2756 "st1 {v17.h}[1], [x3], #2\n"
2757 "st1 {v17.h}[2], [x3], #2\n"
2758 "st1 {v17.h}[3], [x3], #2\n"
2759 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2760 "mov x3, x4\n"
2761 "st1 {v17.h}[4], [x3], #2\n"
2762 "add x4, x4, x11\n"
2763 "st1 {v17.h}[5], [x3], #2\n"
2764 "st1 {v17.h}[6], [x3], #2\n"
2765 "st1 {v17.h}[7], [x3], #2\n"
2766 "31:\n"
2767
2768 "add %[dst_ptr], %[dst_ptr], #8\n"
2769
2770 RUY_MAKE_ZERO(v16)
2771 RUY_MAKE_ZERO(v17)
2772
2773 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2774
2775 RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
2776
2777 "ldr x1, [%[lhs_ptr], #8]\n"
2778 "ldr x2, [%[lhs_ptr], #24]\n"
2779 "ldr x3, [%[lhs_ptr], #40]\n"
2780 "ldr x4, [%[lhs_ptr], #56]\n"
2781
2782 "ins v0.d[1], x1\n"
2783 "ldr x1, [%[rhs_ptr], #8]\n"
2784 "ins v1.d[1], x2\n"
2785 "ldr x2, [%[rhs_ptr], #24]\n"
2786 "ins v2.d[1], x3\n"
2787 "ldr x3, [%[rhs_ptr], #40]\n"
2788 "ins v3.d[1], x4\n"
2789 "ldr x4, [%[rhs_ptr], #56]\n"
2790 "ins v4.d[1], x1\n"
2791 "ins v5.d[1], x2\n"
2792 "ins v6.d[1], x3\n"
2793 "ins v7.d[1], x4\n"
2794
2795 // Since the store type is the same as the accum type, no need for
2796 // downcast. There's also no need for clamp by min/max.
2797
2798 // At this point, v20 -- v31 aren't used anymore for the current block,
2799 // so we can start clearing these accumulators for the next block
2800 // (next iteration of the main loop).
2801
2802 RUY_MAKE_ZERO(v20)
2803 "add %[lhs_ptr], %[lhs_ptr], #64\n"
2804 RUY_MAKE_ZERO(v21)
2805 "add %[rhs_ptr], %[rhs_ptr], #64\n"
2806 RUY_MAKE_ZERO(v22)
2807
2808 RUY_MAKE_ZERO(v23)
2809 RUY_MAKE_ZERO(v24)
2810 RUY_MAKE_ZERO(v25)
2811 RUY_MAKE_ZERO(v26)
2812 RUY_MAKE_ZERO(v27)
2813 RUY_MAKE_ZERO(v28)
2814 RUY_MAKE_ZERO(v29)
2815 RUY_MAKE_ZERO(v30)
2816
2817 // Compute how much of the 4x4 block of destination 32bit values
2818 // we have computed fits in the destination matrix. Typically, all of
2819 // it fits, but when the destination matrix shape is not a multiple
2820 // of 4x4, there are some 4x4 blocks along the boundaries that do
2821 // not fit entirely.
2822 "sub w1, %w[dst_rows], %w[row]\n"
2823 RUY_MAKE_ZERO(v31)
2824 "sub w2, %w[dst_cols], %w[col]\n"
2825 "mov w3, #4\n"
2826 "cmp w1, #4\n"
2827 // Compute w1 = how many rows of the 4x4 block fit
2828 "csel w1, w1, w3, le\n"
2829 "cmp w2, #4\n"
2830 // Compute w2 = how many cols of the 4x4 block fit
2831 "csel w2, w2, w3, le\n"
2832
2833 // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
2834 "cmp w1, w3\n"
2835 "ccmp w2, w3, 0, eq\n"
2836 "mov x4, %[dst_ptr]\n"
2837 // Yes, all of the 4x4 block fits, go to fast path.
2838 "beq 30f\n"
2839 // Not all of the 4x4 block fits.
2840 // Store to dst_tmp_buf
2841 "str q16, [%[dst_tmp_buf], #0]\n"
2842 "str q17, [%[dst_tmp_buf], #16]\n"
2843 "str q18, [%[dst_tmp_buf], #32]\n"
2844 "str q19, [%[dst_tmp_buf], #48]\n"
2845 // Slow loop copying from dst_tmp_buf to dst.
2846 "mov x3, %[dst_tmp_buf]\n"
2847 "mov w6, #0\n"
2848 "50:\n"
2849 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2850 "mov w5, #0\n"
2851 "51:\n"
2852 "ldr w7, [x3, x5, lsl #2]\n"
2853 "str w7, [x4, x5, lsl #2]\n"
2854 "add w5, w5, #1\n"
2855 "cmp w5, w1\n"
2856 "blt 51b\n"
2857 "add w6, w6, #1\n"
2858 "add x3, x3, #16\n"
2859 "add x4, x4, x11\n"
2860 "cmp w6, w2\n"
2861 "blt 50b\n"
2862 "b 31f\n"
2863 "30:\n"
2864 // Yes, all of the 4x4 block fits.
2865 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2866 "mov x3, x4\n"
2867 "st1 {v16.s}[0], [x3], #4\n"
2868 "add x4, x4, x11\n"
2869 "st1 {v16.s}[1], [x3], #4\n"
2870 "st1 {v16.s}[2], [x3], #4\n"
2871 "st1 {v16.s}[3], [x3], #4\n"
2872 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2873 "mov x3, x4\n"
2874 "st1 {v17.s}[0], [x3], #4\n"
2875 "add x4, x4, x11\n"
2876 "st1 {v17.s}[1], [x3], #4\n"
2877 "st1 {v17.s}[2], [x3], #4\n"
2878 "st1 {v17.s}[3], [x3], #4\n"
2879 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2880 "mov x3, x4\n"
2881 "st1 {v18.s}[0], [x3], #4\n"
2882 "add x4, x4, x11\n"
2883 "st1 {v18.s}[1], [x3], #4\n"
2884 "st1 {v18.s}[2], [x3], #4\n"
2885 "st1 {v18.s}[3], [x3], #4\n"
2886 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
2887 "mov x3, x4\n"
2888 "st1 {v19.s}[0], [x3], #4\n"
2889 "add x4, x4, x11\n"
2890 "st1 {v19.s}[1], [x3], #4\n"
2891 "st1 {v19.s}[2], [x3], #4\n"
2892 "st1 {v19.s}[3], [x3], #4\n"
2893 "31:\n"
2894
2895 "add %[dst_ptr], %[dst_ptr], #16\n"
2896
2897 RUY_MAKE_ZERO(v16)
2898 RUY_MAKE_ZERO(v17)
2899 RUY_MAKE_ZERO(v18)
2900 RUY_MAKE_ZERO(v19)
2901
2902 RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
2903
2904 // For the next block: perform the first few multiply-adds on the data
2905 // that we have already loaded.
2906 "smull v8.8h, v0.8b, v4.8b\n"
2907 "smull v9.8h, v1.8b, v4.8b\n"
2908 "smull v10.8h, v2.8b, v4.8b\n"
2909 // Reload some params --- we had used x5 -- x7 for a few other things
2910 // since the last time we had loaded them.
2911 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
2912 "smull v11.8h, v3.8b, v4.8b\n"
2913 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
2914 "smull v12.8h, v0.8b, v5.8b\n"
2915 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
2916 "smull v13.8h, v1.8b, v5.8b\n"
2917 "smull v14.8h, v2.8b, v5.8b\n"
2918 "smull v15.8h, v3.8b, v5.8b\n"
2919 // Move to the next block of the destination matrix, for the next iter
2920 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
2921 // been updated earlier.
2922 // Have we reached the end row?
2923 "cmp %w[row], w7\n"
2924 "smlal2 v8.8h, v0.16b, v4.16b\n"
2925 "smlal2 v9.8h, v1.16b, v4.16b\n"
2926 "smlal2 v10.8h, v2.16b, v4.16b\n"
2927 "smlal2 v11.8h, v3.16b, v4.16b\n"
2928 "smlal2 v12.8h, v0.16b, v5.16b\n"
2929 "smlal2 v13.8h, v1.16b, v5.16b\n"
2930 "smlal2 v14.8h, v2.16b, v5.16b\n"
2931 "smlal2 v15.8h, v3.16b, v5.16b\n"
2932
2933
2934 "beq 20f\n" // yes, end row.
2935 // Not end row. Move to the next row.
2936 "add %w[row], %w[row], #4\n"
2937 "b 21f\n"
2938 "20:\n"
2939 // Was already at end row.
2940 "mov %w[row], w6\n" // Move back to first row.
2941 "add %w[col], %w[col], #4\n" // Move to the next column.
2942 "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
2943 "mov %[dst_ptr], %[dst_col_ptr]\n"
2944 "21:\n"
2945
2946 // Main loop exit condition: have we hit the end column?
2947 "cmp %w[col], w8\n"
2948
2949 // w1 is the number of levels of depth that we have already loaded
2950 // LHS and RHS data for. Corresponding to the initial ld1 instructions
2951 // above, this is currently 16.
2952 "mov w1, #16\n"
2953
2954 "ble 1b\n"
2955
2956 // clang-format on
2957
2958 : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
2959 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
2960 [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
2961 : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
2962 [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
2963 [dst_type_id] "r"(params.dst_type_id)
2964 : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
2965 "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
2966 "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
2967 "v26", "v27", "v28", "v29", "v30", "v31");
2968 }
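
// A scalar model of the down-quantization sequence used in the kernels above:
// sshl by the non-negative exponent part, sqdmulh by the Q31 fixed-point
// multiplier, srshl by the negative exponent part. This sketch is for
// reference only: it is not called by the kernels, it uses a simpler exponent
// split than the asm (which always keeps at least a 1-bit rounding right
// shift), and it ignores the sqdmulh saturation corner case where both
// operands are INT32_MIN.
inline std::int32_t ApplyMultiplierSketch(std::int32_t acc,
                                          std::int32_t multiplier_fixedpoint,
                                          int multiplier_exponent) {
  const int left_shift = multiplier_exponent > 0 ? multiplier_exponent : 0;
  const int right_shift = multiplier_exponent > 0 ? 0 : -multiplier_exponent;
  // sshl: plain (non-saturating) left shift within the 32-bit lane.
  const std::int32_t shifted = static_cast<std::int32_t>(
      static_cast<std::uint32_t>(acc) << left_shift);
  // sqdmulh: high 32 bits of the doubled 64-bit product; the doubling makes
  // this an effective multiplication by multiplier_fixedpoint / 2^31.
  const std::int64_t prod = 2 * static_cast<std::int64_t>(shifted) *
                            static_cast<std::int64_t>(multiplier_fixedpoint);
  const std::int32_t high = static_cast<std::int32_t>(prod >> 32);
  // srshl with a negative shift amount: rounding arithmetic right shift.
  if (right_shift == 0) return high;
  const std::int32_t rounding = static_cast<std::int32_t>(1)
                                << (right_shift - 1);
  return (high + rounding) >> right_shift;
}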
2969
2970 // Kernel taking advantage of the optional dotprod instruction.
2971 // This is very similar to (and directly inspired by) this gemmlowp kernel
2972 // which was contributed by David Mansell at ARM:
2973 // NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct
2974 // https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3391
2975 //
2976 // Besides the ruy-ification, the main difference here is that we use an 8x8
2977 // instead of a 12x8 width, so as to stick to power-of-two widths. This
2978 // slightly narrower kernel layout is still wide enough to achieve high
2979 // performance, although we haven't actually performed a real comparison to
2980 // know exactly how this compares to ARM's aforementioned kernel.
2981 //
2982 // Relevant target CPUs for this kernel include ARM Cortex-A76,
2983 // since these are 64-bit, out-of-order and with dotprod support.
2984 void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params) {
2985 profiler::ScopeLabel label("Kernel (kNeonDotprod)");
2986
2987 CheckOffsetsInKernelParams8bit(params);
2988
2989 const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
2990 const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
2991 const std::int8_t* lhs_ptr = lhs_col_ptr;
2992 const std::int8_t* rhs_ptr = rhs_col_ptr;
2993 void* dst_col_ptr = params.dst_base_ptr;
2994 void* dst_ptr = dst_col_ptr;
2995 int row = params.start_row;
2996 int col = params.start_col;
2997
2998 // The asm kernel below has the following NEON register allocation:
2999 //
3000 // v16 -- v31 are int32 accumulators.
3001 // During accumulation, v0 -- v15 are used to load int8 data from LHS and
3002 // RHS. At least v0 and v1 are used to load a 8x4 block of LHS, and v2 and
3003 // v3 are used to load a 4x8 block of RHS, like this:
3004 //
3005 // int8 RHS 4x8 block
3006 // /-----------------------------------------|
3007 // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]|
3008 // | ... ... |
3009 // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]|
3010 // \-----------------------------------------/
3011 // int8 LHS 8x4 block
3012 // /---------------------\ /-----------------------------------------|
3013 // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]|
3014 // | ... ... | | ... ... |
3015 // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]|
3016 // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]|
3017 // | ... ... | | ... ... |
3018 // |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]|
3019 // \---------------------/ \-----------------------------------------/
3020 // int32 accumulators 8x8 block
3021 //
3022 // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step
3023 // is repeated 4 times, using 4x more registers for LHS and RHS: instead
3024 // of using only v0 -- v3 for LHS and RHS, we use v0 -- v15.
3025 //
3026 // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are
3027 // unused, and v8 -- v15 are used for loading parameters used for the
3028 // post-accumulation part of the kernel.
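// Each sdot instruction computes, per 32-bit lane of the destination, the
// dot product of a group of four consecutive int8 values from the first
// source register with a broadcast group of four int8 values from the second
// source, accumulating into the int32 lane. One sdot therefore covers 4
// levels of depth for 4 rows (or columns) at once.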
3029 asm volatile(
3030 #define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
3031
3032 // clang-format off
3033
3034 // Load some parameters into registers.
3035 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
3036 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
3037 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
3038 "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
3039 "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
3040 "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
3041 "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
3042 "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
3043
3044 // Load the first 32 bytes of LHS and RHS data.
3045 "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
3046 "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
3047 "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
3048 "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
3049
3050 // Clear accumulators.
3051 RUY_MAKE_ZERO(v16)
3052 RUY_MAKE_ZERO(v17)
3053 RUY_MAKE_ZERO(v18)
3054 RUY_MAKE_ZERO(v19)
3055 RUY_MAKE_ZERO(v20)
3056 RUY_MAKE_ZERO(v21)
3057 RUY_MAKE_ZERO(v22)
3058 RUY_MAKE_ZERO(v23)
3059 RUY_MAKE_ZERO(v24)
3060 RUY_MAKE_ZERO(v25)
3061 RUY_MAKE_ZERO(v26)
3062 RUY_MAKE_ZERO(v27)
3063 RUY_MAKE_ZERO(v28)
3064 RUY_MAKE_ZERO(v29)
3065 RUY_MAKE_ZERO(v30)
3066 RUY_MAKE_ZERO(v31)
3067
3068 // w1 is the number of levels of depth that we have already loaded
3069 // LHS and RHS data for. Corresponding to the initial ld1 instructions
3070 // above, this is currently 4.
3071 "mov w1, #4\n"
3072
3073 // Perform the first few multiply-adds on the data that we have already
3074 // loaded.
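// The sdot instructions are written as raw ".word" encodings, with the
// intended mnemonic in a trailing comment, so that this file assembles even
// with toolchains that do not recognize the dotprod extension.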
3075 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
3076 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
3077 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
3078 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
3079
3080 // Main loop of the whole GEMM, over rows and columns of the
3081 // destination matrix.
3082 "1:\n"
3083
3084 // Optional, maximally-streaming, partial-unrolling (4x unrolled)
3085 // optimization of the kernel inner loop (over depth). For more
3086 // comments, see the non-unrolled loop below after the #endif.
3087 #if RUY_OPT(MAX_STREAMING)
3088 "cmp w12, #32\n"
3089 "blt 78f\n"
3090
3091 "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
3092 "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
3093 "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
3094 "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
3095 "ld1 {v8.16b}, [%[lhs_ptr]], #16\n"
3096 "ld1 {v9.16b}, [%[lhs_ptr]], #16\n"
3097 "ld1 {v10.16b}, [%[rhs_ptr]], #16\n"
3098 "ld1 {v11.16b}, [%[rhs_ptr]], #16\n"
3099 "ld1 {v12.16b}, [%[lhs_ptr]], #16\n"
3100 "ld1 {v13.16b}, [%[lhs_ptr]], #16\n"
3101 "ld1 {v14.16b}, [%[rhs_ptr]], #16\n"
3102 "ld1 {v15.16b}, [%[rhs_ptr]], #16\n"
3103 "mov w1, #16\n"
3104
3105 "and w3, w12, #-16\n"
3106 "81:\n"
3107 "add w1, w1, #16\n"
3108
3109 ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
3110 ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
3111 ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
3112 ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
3113 "ldr q0, [%[lhs_ptr], #0]\n"
3114 ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
3115 ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
3116 ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
3117 ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
3118 "ldr q2, [%[rhs_ptr], #0]\n"
3119 ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
3120 ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
3121 ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
3122 ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
3123 "ldr q1, [%[lhs_ptr], #16]\n"
3124
3125 ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n"
3126 ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n"
3127 "ldr q3, [%[rhs_ptr], #16]\n"
3128 ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n"
3129 ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n"
3130 ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n"
3131 ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n"
3132 ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n"
3133 ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n"
3134 ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n"
3135 ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n"
3136 ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n"
3137 ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n"
3138 "ldr q5, [%[lhs_ptr], #48]\n"
3139 ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n"
3140 ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n"
3141 "ldr q7, [%[rhs_ptr], #48]\n"
3142 ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n"
3143 ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n"
3144 "ldr q4, [%[lhs_ptr], #32]\n"
3145
3146 ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n"
3147 ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n"
3148 "ldr q6, [%[rhs_ptr], #32]\n"
3149 ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n"
3150 ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n"
3151 ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n"
3152 ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n"
3153 ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n"
3154 ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n"
3155 ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n"
3156 ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n"
3157 ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n"
3158 ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n"
3159 "ldr q9, [%[lhs_ptr], #80]\n"
3160 ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n"
3161 ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n"
3162 "ldr q11, [%[rhs_ptr], #80]\n"
3163 ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n"
3164 ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n"
3165 "ldr q8, [%[lhs_ptr], #64]\n"
3166
3167 ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n"
3168 ".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n"
3169 "ldr q10, [%[rhs_ptr], #64]\n"
3170 ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n"
3171 ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n"
3172 "add %[lhs_ptr], %[lhs_ptr], #128\n"
3173 ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n"
3174 ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n"
3175 "add %[rhs_ptr], %[rhs_ptr], #128\n"
3176 ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n"
3177 ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n"
3178 ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n"
3179 ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n"
3180 "cmp w1, w3\n"
3181 ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n"
3182 ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n"
3183 "ldr q13, [%[lhs_ptr], #-16]\n"
3184 ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n"
3185 ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n"
3186 "ldr q15, [%[rhs_ptr], #-16]\n"
3187 ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n"
3188 ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n"
3189 "ldr q12, [%[lhs_ptr], #-32]\n"
3190
3191 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
3192 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
3193 "ldr q14, [%[rhs_ptr], #-32]\n"
3194 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
3195 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
3196
3197 "blt 81b\n"
3198
3199 ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n"
3200 ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n"
3201 ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n"
3202 ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n"
3203 ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n"
3204 ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n"
3205 ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n"
3206 ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n"
3207 ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n"
3208 ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n"
3209 ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n"
3210 ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n"
3211 ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n"
3212 ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n"
3213 ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n"
3214 ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n"
3215
3216 ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n"
3217 ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n"
3218 ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n"
3219 ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n"
3220 ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n"
3221 ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n"
3222 ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n"
3223 ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n"
3224 ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n"
3225 ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n"
3226 ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n"
3227 ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n"
3228 ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n"
3229 ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n"
3230 ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n"
3231 ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n"
3232
3233 ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n"
3234 ".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n"
3235 ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n"
3236 ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n"
3237 ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n"
3238 ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n"
3239 ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n"
3240 ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n"
3241 ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n"
3242 ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n"
3243 ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n"
3244 ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n"
3245 ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n"
3246 ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n"
3247 ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n"
3248 ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n"
3249
3250 "78:\n"
3251
3252 #endif // #if RUY_OPT(MAX_STREAMING)
3253
3254 // Ordinary kernel inner loop (over depth), the simpler loop that the
3255 // above was an equivalent 4x-partially-unrolled version of.
3256
3257 // Reminder - w1 is how many levels of depth we have already loaded
3258 // data for, w12 is the total depth.
3259 "cmp w1, w12\n"
3260 "beq 79f\n"
3261
3262 "2:\n"
3263
3264 // Because of the data that we have already loaded, we can start the
3265 // loop body right away with some multiply-adds.
3266 ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
3267 ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
3268 // Each iteration of this loop advances by 4 levels of depth.
3269 "add w1, w1, #4\n"
3270 ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
3271 ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
3272 "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
3273 ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
3274 ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
3275 // Loop termination condition.
3276 "cmp w1, w12\n"
3277 ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
3278 ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
3279 "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
3280 ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
3281 ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
3282 ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
3283 ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
3284 "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
3285 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
3286 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
3287 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
3288 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
3289 "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
3290
3291 "blt 2b\n"
3292
3293 "79:\n"
3294 // End of the inner loop on depth. Now perform the remaining
3295 // multiply-adds of the last 4 levels of depth, for which the LHS
3296 // and RHS data is already loaded.
3297
3298 ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
3299 ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
3300 ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
3301 ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
3302 ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
3303 ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
3304 ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
3305 ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
3306 ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
3307 ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
3308 ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
3309 ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
3310
3311 // End of accumulation. The registers v16 -- v31 contain the final
3312 // int32 accumulator values of the current 8x8 destination block.
3313 // We now have to compute the final 8-bit values from these int32
3314 // accumulators, and advance to the next 8x8 block. We intertwine
3315 // these two aspects whenever possible for optimal pipelining, both
3316 // at the data flow level (prefetch data for next block as early as
3317 // possible) and instruction pipelining level (some of the next-block
3318 // work can dual-issue with some of the final work on the current
3319 // block).
3320
3321 // Logic to advance to the next block in preparation for the next
3322 // iteration of the main loop. For now, we only want to compute
3323 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
3324 // not yet ready to update the values of row and col, as we still need
3325 // the current values for the rest of the work on the current block.
3326
3327 "cmp %w[row], w7\n" // Have we finished the last row?
3328 "bge 4f\n" // If finished last row, go to 4
3329 // Not finished last row: then advance to next row.
3330 "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
3331 "b 5f\n"
3332 "4:\n" // Finished last row...
3333 "mov %[lhs_col_ptr], x5\n" // Go back to first row
3334 // Now we need to advance to the next column. If we already
3335 // finished the last column, then in principle we are done, however
3336 // we can't just return here, as we need to allow the end work of the
3337 // current block to complete. The good news is that at this point it
3338 // doesn't matter what data we load for the next column, since
3339 // we will exit from the main loop below before actually storing
3340 // anything computed from that data.
3341 "cmp %w[col], w8\n" // Have we finished the last column?
3342 "bge 5f\n" // If yes, just carry on without updating the column pointer.
3343 // Not finished last column: then advance to next column.
3344 "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
3345 "5:\n"
3346
3347 // Set the LHS and RHS data pointers to the start of the columns just
3348 // computed.
3349 "mov %[lhs_ptr], %[lhs_col_ptr]\n"
3350 "mov %[rhs_ptr], %[rhs_col_ptr]\n"
3351
3352 // Load some parameters needed for the end work on current block.
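// v8 := -1 in every int32 lane (mvni with immediate 0); it is used below
// to split the multiplier exponents into negative and non-negative parts.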
3353 "mvni v8.4s, #0\n"
3354 "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
3355 "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
3356 "dup v9.4s, w3\n" // create prod_zp_depth_vec
3357
3358 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
3359 // Determine the channel index.
3360 "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
3361 "csel w3, %w[row], %w[col], eq\n"
3362
3363 // Offset the bias pointer as needed given the current row, col.
3364 "add x5, x1, x3, lsl #2\n"
3365
3366 // If there is no bias, use no offset, just address the passed zero
3367 // data.
3368 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
3369 "csel x1, x1, x5, eq\n"
3370
3371 // Load 8 bias values.
3372 "ld1 {v14.4s}, [x1], #16\n"
3373 "ld1 {v15.4s}, [x1]\n"
3374
3375 // Now that we know what LHS and RHS data the next iteration of the
3376 // main loop will need to load, we start loading the first 32 bytes of
3377 // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
3378 // in the rest of the work on the current block.
3379 "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
3380 "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
3381 "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
3382 "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
3383
3384 // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
3385 // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
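// In full, expanding sum_d (lhs[r,d] - lhs_zero_point)*(rhs[d,c] - rhs_zero_point)
// per equation (7) there, the int32 value we need is
//   acc[r,c] = sum_d lhs[r,d]*rhs[d,c]
//            - lhs_zero_point * rhs_sums[c]
//            - rhs_zero_point * lhs_sums[r]
//            + depth * lhs_zero_point * rhs_zero_point   (plus the bias).
// The depth*zp*zp term and the bias are folded in here; the two sums
// corrections are applied in the HAS_RHS_SUMS / HAS_LHS_SUMS sections below.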
3386 "add v14.4s, v14.4s, v9.4s\n"
3387 "add v15.4s, v15.4s, v9.4s\n"
3388
3389 // Perform the bias-addition (per the above, we have just folded into
3390 // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
3391 // Jump based on channel dimension.
3392 "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
3393 "bne 6f\n"
3394 // Case where channels are rows
3395 "add v16.4s, v16.4s, v14.4s\n"
3396 "add v17.4s, v17.4s, v15.4s\n"
3397 "add v18.4s, v18.4s, v14.4s\n"
3398 "add v19.4s, v19.4s, v15.4s\n"
3399 "add v20.4s, v20.4s, v14.4s\n"
3400 "add v21.4s, v21.4s, v15.4s\n"
3401 "add v22.4s, v22.4s, v14.4s\n"
3402 "add v23.4s, v23.4s, v15.4s\n"
3403 "add v24.4s, v24.4s, v14.4s\n"
3404 "add v25.4s, v25.4s, v15.4s\n"
3405 "add v26.4s, v26.4s, v14.4s\n"
3406 "add v27.4s, v27.4s, v15.4s\n"
3407 "add v28.4s, v28.4s, v14.4s\n"
3408 "add v29.4s, v29.4s, v15.4s\n"
3409 "add v30.4s, v30.4s, v14.4s\n"
3410 "add v31.4s, v31.4s, v15.4s\n"
3411 "b 7f\n"
3412
3413 "6:\n"
3414 // Case where channels are columns
3415 "dup v10.4s, v14.s[0]\n"
3416 "dup v11.4s, v14.s[1]\n"
3417 "dup v12.4s, v14.s[2]\n"
3418 "dup v13.4s, v14.s[3]\n"
3419 "add v16.4s, v16.4s, v10.4s\n"
3420 "add v17.4s, v17.4s, v10.4s\n"
3421 "add v18.4s, v18.4s, v11.4s\n"
3422 "add v19.4s, v19.4s, v11.4s\n"
3423 "add v20.4s, v20.4s, v12.4s\n"
3424 "add v21.4s, v21.4s, v12.4s\n"
3425 "add v22.4s, v22.4s, v13.4s\n"
3426 "add v23.4s, v23.4s, v13.4s\n"
3427 "dup v10.4s, v15.s[0]\n"
3428 "dup v11.4s, v15.s[1]\n"
3429 "dup v12.4s, v15.s[2]\n"
3430 "dup v13.4s, v15.s[3]\n"
3431 "add v24.4s, v24.4s, v10.4s\n"
3432 "add v25.4s, v25.4s, v10.4s\n"
3433 "add v26.4s, v26.4s, v11.4s\n"
3434 "add v27.4s, v27.4s, v11.4s\n"
3435 "add v28.4s, v28.4s, v12.4s\n"
3436 "add v29.4s, v29.4s, v12.4s\n"
3437 "add v30.4s, v30.4s, v13.4s\n"
3438 "add v31.4s, v31.4s, v13.4s\n"
3439 "7:\n"
3440
3441 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
3442 "beq 401f\n"
3443 "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
3444 "add x3, x3, %x[col], lsl #2\n"
3445 "ld1 {v14.4s}, [x3], #16\n"
3446 "ld1 {v15.4s}, [x3]\n"
3447 "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
3448 "dup v10.4s, w5\n" // create lhs_zero_point_vec
3449 // Subtract rhs_sums * lhs_zero_point, per
3450 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
3451 "mls v16.4s, v10.4s, v14.s[0]\n"
3452 "mls v17.4s, v10.4s, v14.s[0]\n"
3453 "mls v18.4s, v10.4s, v14.s[1]\n"
3454 "mls v19.4s, v10.4s, v14.s[1]\n"
3455 "mls v20.4s, v10.4s, v14.s[2]\n"
3456 "mls v21.4s, v10.4s, v14.s[2]\n"
3457 "mls v22.4s, v10.4s, v14.s[3]\n"
3458 "mls v23.4s, v10.4s, v14.s[3]\n"
3459 "mls v24.4s, v10.4s, v15.s[0]\n"
3460 "mls v25.4s, v10.4s, v15.s[0]\n"
3461 "mls v26.4s, v10.4s, v15.s[1]\n"
3462 "mls v27.4s, v10.4s, v15.s[1]\n"
3463 "mls v28.4s, v10.4s, v15.s[2]\n"
3464 "mls v29.4s, v10.4s, v15.s[2]\n"
3465 "mls v30.4s, v10.4s, v15.s[3]\n"
3466 "mls v31.4s, v10.4s, v15.s[3]\n"
3467 "401:\n"
3468
3469 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
3470 "beq 402f\n"
3471 "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
3472 "add x2, x2, %x[row], lsl #2\n"
3473 "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
3474 // Load 4 lhs_sums values.
3475 "ld1 {v11.4s}, [x2], #16\n"
3476 "ld1 {v12.4s}, [x2]\n"
3477 "ins v13.s[1], w5\n" // rhs_zero_point
3478 // Compute lhs_sums * rhs_zero_point.
3479 "mul v11.4s, v11.4s, v13.s[1]\n"
3480 "mul v12.4s, v12.4s, v13.s[1]\n"
3481 // Subtract lhs_sums * rhs_zero_point, per
3482 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
3483 "sub v16.4s, v16.4s, v11.4s\n"
3484 "sub v17.4s, v17.4s, v12.4s\n"
3485 "sub v18.4s, v18.4s, v11.4s\n"
3486 "sub v19.4s, v19.4s, v12.4s\n"
3487 "sub v20.4s, v20.4s, v11.4s\n"
3488 "sub v21.4s, v21.4s, v12.4s\n"
3489 "sub v22.4s, v22.4s, v11.4s\n"
3490 "sub v23.4s, v23.4s, v12.4s\n"
3491 "sub v24.4s, v24.4s, v11.4s\n"
3492 "sub v25.4s, v25.4s, v12.4s\n"
3493 "sub v26.4s, v26.4s, v11.4s\n"
3494 "sub v27.4s, v27.4s, v12.4s\n"
3495 "sub v28.4s, v28.4s, v11.4s\n"
3496 "sub v29.4s, v29.4s, v12.4s\n"
3497 "sub v30.4s, v30.4s, v11.4s\n"
3498 "sub v31.4s, v31.4s, v12.4s\n"
3499
3500 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
3501 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
3502
3503 "402:\n"
3504
3505 // At this point we have computed the final int32 values. Now we
3506 // start down-quantizing them to obtain the final 8bit values from them.
3507
3508 // As part of this down-quantization, our int32 values will be
3509 // multiplied by a multiplier that has a fixed-point component and an
3510 // exponent component.
3511
3512         //Load the exponent part of the multiplier.
3512         // Load the exponent part of the multiplier.
3513 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
3514 // Determine the channel index.
3515 "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
3516 "csel w3, %w[row], %w[col], eq\n"
3517 // Compute the multiplier_exponent pointer
3518 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
3519 "add x5, x1, x3, lsl #2\n"
3520 "csel x1, x1, x5, eq\n"
3521 // Load multiplier_exponent
3522 "ldr q9, [x1]\n"
3523 "ldr q10, [x1, #16]\n"
3524 // Separate positive and negative exponents
3525 "smin v11.4s, v8.4s, v9.4s\n"
3526 "smin v12.4s, v8.4s, v10.4s\n"
3527 "sub v9.4s, v9.4s, v11.4s\n"
3528 "sub v10.4s, v10.4s, v12.4s\n"
3529
3530 // Compute the multiplier_fixedpoint pointer
3531 "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
3532 "add x5, x4, x3, lsl #2\n"
3533 "csel x4, x4, x5, eq\n"
3534 // Load multiplier_fixedpoint
3535 "ldr q14, [x4]\n"
3536 "ldr q15, [x4, #16]\n"
3537
3538 // Jump based on channel dimension.
3539 "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
3540 "bne 8f\n"
3541 // Case where channels are rows
3542
3543 // Apply the positive exponent part of the multiplier.
3544 "sshl v16.4s, v16.4s, v9.4s\n"
3545 "sshl v17.4s, v17.4s, v10.4s\n"
3546 "sshl v18.4s, v18.4s, v9.4s\n"
3547 "sshl v19.4s, v19.4s, v10.4s\n"
3548 "sshl v20.4s, v20.4s, v9.4s\n"
3549 "sshl v21.4s, v21.4s, v10.4s\n"
3550 "sshl v22.4s, v22.4s, v9.4s\n"
3551 "sshl v23.4s, v23.4s, v10.4s\n"
3552 "sshl v24.4s, v24.4s, v9.4s\n"
3553 "sshl v25.4s, v25.4s, v10.4s\n"
3554 "sshl v26.4s, v26.4s, v9.4s\n"
3555 "sshl v27.4s, v27.4s, v10.4s\n"
3556 "sshl v28.4s, v28.4s, v9.4s\n"
3557 "sshl v29.4s, v29.4s, v10.4s\n"
3558 "sshl v30.4s, v30.4s, v9.4s\n"
3559 "sshl v31.4s, v31.4s, v10.4s\n"
3560 "10:\n"
3561
3562 // Apply the fixed-point part of the multiplier.
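// sqdmulh is the saturating doubling multiply returning the high half,
// (2*a*b) >> 32, i.e. multiplication by a Q31 fixed-point multiplier.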
3563 "sqdmulh v16.4s, v16.4s, v14.4s\n"
3564 "sqdmulh v17.4s, v17.4s, v15.4s\n"
3565 "sqdmulh v18.4s, v18.4s, v14.4s\n"
3566 "sqdmulh v19.4s, v19.4s, v15.4s\n"
3567 "sqdmulh v20.4s, v20.4s, v14.4s\n"
3568 "sqdmulh v21.4s, v21.4s, v15.4s\n"
3569 "sqdmulh v22.4s, v22.4s, v14.4s\n"
3570 "sqdmulh v23.4s, v23.4s, v15.4s\n"
3571 "sqdmulh v24.4s, v24.4s, v14.4s\n"
3572 "sqdmulh v25.4s, v25.4s, v15.4s\n"
3573 "sqdmulh v26.4s, v26.4s, v14.4s\n"
3574 "sqdmulh v27.4s, v27.4s, v15.4s\n"
3575 "sqdmulh v28.4s, v28.4s, v14.4s\n"
3576 "sqdmulh v29.4s, v29.4s, v15.4s\n"
3577 "sqdmulh v30.4s, v30.4s, v14.4s\n"
3578 "sqdmulh v31.4s, v31.4s, v15.4s\n"
3579
3580 // Apply the negative exponent part of the multiplier.
3581 "srshl v16.4s, v16.4s, v11.4s\n"
3582 "srshl v17.4s, v17.4s, v12.4s\n"
3583 "srshl v18.4s, v18.4s, v11.4s\n"
3584 "srshl v19.4s, v19.4s, v12.4s\n"
3585 "srshl v20.4s, v20.4s, v11.4s\n"
3586 "srshl v21.4s, v21.4s, v12.4s\n"
3587 "srshl v22.4s, v22.4s, v11.4s\n"
3588 "srshl v23.4s, v23.4s, v12.4s\n"
3589 "srshl v24.4s, v24.4s, v11.4s\n"
3590 "srshl v25.4s, v25.4s, v12.4s\n"
3591 "srshl v26.4s, v26.4s, v11.4s\n"
3592 "srshl v27.4s, v27.4s, v12.4s\n"
3593 "srshl v28.4s, v28.4s, v11.4s\n"
3594 "srshl v29.4s, v29.4s, v12.4s\n"
3595 "srshl v30.4s, v30.4s, v11.4s\n"
3596 "srshl v31.4s, v31.4s, v12.4s\n"
3597 "b 9f\n"
3598
3599 "8:\n"
3600 // Case where channels are columns
3601
3602 // Apply the positive exponent part of the multiplier.
3603 "dup v4.4s, v9.s[0]\n"
3604 "dup v5.4s, v9.s[1]\n"
3605 "dup v6.4s, v9.s[2]\n"
3606 "dup v7.4s, v9.s[3]\n"
3607 "sshl v16.4s, v16.4s, v4.4s\n"
3608 "sshl v17.4s, v17.4s, v4.4s\n"
3609 "sshl v18.4s, v18.4s, v5.4s\n"
3610 "sshl v19.4s, v19.4s, v5.4s\n"
3611 "sshl v20.4s, v20.4s, v6.4s\n"
3612 "sshl v21.4s, v21.4s, v6.4s\n"
3613 "sshl v22.4s, v22.4s, v7.4s\n"
3614 "sshl v23.4s, v23.4s, v7.4s\n"
3615 "dup v4.4s, v10.s[0]\n"
3616 "dup v5.4s, v10.s[1]\n"
3617 "dup v6.4s, v10.s[2]\n"
3618 "dup v7.4s, v10.s[3]\n"
3619 "sshl v24.4s, v24.4s, v4.4s\n"
3620 "sshl v25.4s, v25.4s, v4.4s\n"
3621 "sshl v26.4s, v26.4s, v5.4s\n"
3622 "sshl v27.4s, v27.4s, v5.4s\n"
3623 "sshl v28.4s, v28.4s, v6.4s\n"
3624 "sshl v29.4s, v29.4s, v6.4s\n"
3625 "sshl v30.4s, v30.4s, v7.4s\n"
3626 "sshl v31.4s, v31.4s, v7.4s\n"
3627 "11:\n"
3628
3629 // Apply the fixed-point part of the multiplier.
3630 "sqdmulh v16.4s, v16.4s, v14.s[0]\n"
3631 "sqdmulh v17.4s, v17.4s, v14.s[0]\n"
3632 "sqdmulh v18.4s, v18.4s, v14.s[1]\n"
3633 "sqdmulh v19.4s, v19.4s, v14.s[1]\n"
3634 "sqdmulh v20.4s, v20.4s, v14.s[2]\n"
3635 "sqdmulh v21.4s, v21.4s, v14.s[2]\n"
3636 "sqdmulh v22.4s, v22.4s, v14.s[3]\n"
3637 "sqdmulh v23.4s, v23.4s, v14.s[3]\n"
3638 "sqdmulh v24.4s, v24.4s, v15.s[0]\n"
3639 "sqdmulh v25.4s, v25.4s, v15.s[0]\n"
3640 "sqdmulh v26.4s, v26.4s, v15.s[1]\n"
3641 "sqdmulh v27.4s, v27.4s, v15.s[1]\n"
3642 "sqdmulh v28.4s, v28.4s, v15.s[2]\n"
3643 "sqdmulh v29.4s, v29.4s, v15.s[2]\n"
3644 "sqdmulh v30.4s, v30.4s, v15.s[3]\n"
3645 "sqdmulh v31.4s, v31.4s, v15.s[3]\n"
3646
3647 // Apply the negative exponent part of the multiplier.
3648 "dup v4.4s, v11.s[0]\n"
3649 "dup v5.4s, v11.s[1]\n"
3650 "dup v6.4s, v11.s[2]\n"
3651 "dup v7.4s, v11.s[3]\n"
3652 "srshl v16.4s, v16.4s, v4.4s\n"
3653 "srshl v17.4s, v17.4s, v4.4s\n"
3654 "srshl v18.4s, v18.4s, v5.4s\n"
3655 "srshl v19.4s, v19.4s, v5.4s\n"
3656 "srshl v20.4s, v20.4s, v6.4s\n"
3657 "srshl v21.4s, v21.4s, v6.4s\n"
3658 "srshl v22.4s, v22.4s, v7.4s\n"
3659 "srshl v23.4s, v23.4s, v7.4s\n"
3660 "dup v4.4s, v12.s[0]\n"
3661 "dup v5.4s, v12.s[1]\n"
3662 "dup v6.4s, v12.s[2]\n"
3663 "dup v7.4s, v12.s[3]\n"
3664 "srshl v24.4s, v24.4s, v4.4s\n"
3665 "srshl v25.4s, v25.4s, v4.4s\n"
3666 "srshl v26.4s, v26.4s, v5.4s\n"
3667 "srshl v27.4s, v27.4s, v5.4s\n"
3668 "srshl v28.4s, v28.4s, v6.4s\n"
3669 "srshl v29.4s, v29.4s, v6.4s\n"
3670 "srshl v30.4s, v30.4s, v7.4s\n"
3671 "srshl v31.4s, v31.4s, v7.4s\n"
3672 "9:\n"
3673
3674 "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
3675 "ins v13.h[4], w4\n" // dst_zero_point
3676
3677 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
3678 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
3679 "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
3680 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
3681
3682 RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
3683
3684 // Cast-and-saturate from int32 to int16
3685 "sqxtn v16.4h, v16.4s\n"
3686 "sqxtn2 v16.8h, v17.4s\n"
3687 "sqxtn v17.4h, v18.4s\n"
3688 "sqxtn2 v17.8h, v19.4s\n"
3689 "sqxtn v18.4h, v20.4s\n"
3690 "sqxtn2 v18.8h, v21.4s\n"
3691 "sqxtn v19.4h, v22.4s\n"
3692 "sqxtn2 v19.8h, v23.4s\n"
3693 "sqxtn v20.4h, v24.4s\n"
3694 "sqxtn2 v20.8h, v25.4s\n"
3695 "sqxtn v21.4h, v26.4s\n"
3696 "sqxtn2 v21.8h, v27.4s\n"
3697 "sqxtn v22.4h, v28.4s\n"
3698 "sqxtn2 v22.8h, v29.4s\n"
3699 "sqxtn v23.4h, v30.4s\n"
3700 "sqxtn2 v23.8h, v31.4s\n"
3701
3702 // At this point, v24 -- v31 aren't used anymore for the current block,
3703 // so we can start clearing these accumulators for the next block
3704 // (next iteration of the main loop).
3705 RUY_MAKE_ZERO(v24)
3706 RUY_MAKE_ZERO(v25)
3707 RUY_MAKE_ZERO(v26)
3708 RUY_MAKE_ZERO(v27)
3709 RUY_MAKE_ZERO(v28)
3710 RUY_MAKE_ZERO(v29)
3711 RUY_MAKE_ZERO(v30)
3712 RUY_MAKE_ZERO(v31)
3713
3714 // Add the destination zero point
3715 "dup v14.8h, v13.h[4]\n"
3716 "add v16.8h, v16.8h, v14.8h\n"
3717 "add v17.8h, v17.8h, v14.8h\n"
3718 "add v18.8h, v18.8h, v14.8h\n"
3719 "add v19.8h, v19.8h, v14.8h\n"
3720 "add v20.8h, v20.8h, v14.8h\n"
3721 "add v21.8h, v21.8h, v14.8h\n"
3722 "add v22.8h, v22.8h, v14.8h\n"
3723 "add v23.8h, v23.8h, v14.8h\n"
3724
3725 // Cast-and-saturate from int16 to uint8
3726 "sqxtun v16.8b, v16.8h\n"
3727 "sqxtun2 v16.16b, v17.8h\n"
3728 "sqxtun v17.8b, v18.8h\n"
3729 "sqxtun2 v17.16b, v19.8h\n"
3730 "sqxtun v18.8b, v20.8h\n"
3731 "sqxtun2 v18.16b, v21.8h\n"
3732 "sqxtun v19.8b, v22.8h\n"
3733 "sqxtun2 v19.16b, v23.8h\n"
3734
3735 // Load the clamp_min, clamp_max bounds
3736 "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
3737 "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
3738 "dup v14.16b, w2\n" // clamp_min
3739 "dup v15.16b, w3\n" // clamp_max
3740
3741 // Apply the clamp_min bound
3742 "umax v16.16b, v16.16b, v14.16b\n"
3743 "umax v17.16b, v17.16b, v14.16b\n"
3744 "umax v18.16b, v18.16b, v14.16b\n"
3745 "umax v19.16b, v19.16b, v14.16b\n"
3746
3747 // Apply the clamp_max bound
3748 "umin v16.16b, v16.16b, v15.16b\n"
3749 "umin v17.16b, v17.16b, v15.16b\n"
3750 "umin v18.16b, v18.16b, v15.16b\n"
3751 "umin v19.16b, v19.16b, v15.16b\n"
3752
3753 // Make it so that all of the final 8bit values are stored in the
3754 // first 64bits of 128bit NEON registers, so they can be stored
3755 // by 64bit st1 store instructions with byte alignment.
3756 "dup d20, v16.d[1]\n"
3757 "dup d21, v17.d[1]\n"
3758 "dup d22, v18.d[1]\n"
3759 "dup d23, v19.d[1]\n"
3760
3761         // Compute how much of the 8x8 block of destination 8bit values that
3762         // we have computed fits in the destination matrix. Typically, all of
3763 // it fits, but when the destination matrix shape is not a multiple
3764 // of 8x8, there are some 8x8 blocks along the boundaries that do
3765 // not fit entirely.
3766 "sub w1, %w[dst_rows], %w[row]\n"
3767 "sub w2, %w[dst_cols], %w[col]\n"
3768 "mov w3, #8\n"
3769 "cmp w1, #8\n"
3770 // Compute w1 = how many rows of the 8x8 block fit
3771 "csel w1, w1, w3, le\n"
3772 "cmp w2, #8\n"
3773 // Compute w2 = how many cols of the 8x8 block fit
3774 "csel w2, w2, w3, le\n"
3775
3776 // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
3777 "cmp w1, w3\n"
3778 "ccmp w2, w3, 0, eq\n"
3779 // Yes, all of the 8x8 block fits, go to fast path.
3780 "beq 30f\n"
3781 // Not all of the 8x8 block fits.
3782 // Set (x3 address, x4 stride) to write to dst_tmp_buf
3783 "mov x3, %[dst_tmp_buf]\n"
3784 "mov x4, #8\n"
3785 "b 31f\n"
3786 "30:\n"
3787 // Yes, all of the 8x8 block fits.
3788 // Set (x3 address, x4 stride) to write directly to destination matrix.
3789 "mov x3, %[dst_ptr]\n"
3790 "mov x4, x11\n"
3791 "31:\n"
3792
3793 // Write our 8bit values to the destination described by
3794 // (x3 address, x4 stride).
3795 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3796 "st1 {v16.8b}, [x3], x4\n"
3797 RUY_MAKE_ZERO(v16)
3798 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3799 "st1 {v20.8b}, [x3], x4\n"
3800 RUY_MAKE_ZERO(v20)
3801 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3802 "st1 {v17.8b}, [x3], x4\n"
3803 RUY_MAKE_ZERO(v17)
3804 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3805 "st1 {v21.8b}, [x3], x4\n"
3806 RUY_MAKE_ZERO(v21)
3807 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3808 "st1 {v18.8b}, [x3], x4\n"
3809 RUY_MAKE_ZERO(v18)
3810 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3811 "st1 {v22.8b}, [x3], x4\n"
3812 RUY_MAKE_ZERO(v22)
3813 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3814 "st1 {v19.8b}, [x3], x4\n"
3815 RUY_MAKE_ZERO(v19)
3816 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3817 "st1 {v23.8b}, [x3], x4\n"
3818 RUY_MAKE_ZERO(v23)
3819
3820 // For the next block: perform the first few multiply-adds on the data
3821 // that we have already loaded.
3822 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
3823 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
3824 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
3825 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
3826
3827 // If all of the 8x8 block fits, we just finished writing it to the
3828 // destination, so we skip the next part.
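// (This branch reuses the NZCV flags set by the ccmp before the stores:
// none of the intervening st1/movi/prfm/sdot/mov instructions modify flags.)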
3829 "beq 41f\n"
3830 // Not all of the 8x8 block fits in the destination matrix. We just
3831 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
3832 // it to copy into the destination matrix the part that fits.
3833 "mov x3, %[dst_tmp_buf]\n"
3834 "mov x4, %[dst_ptr]\n"
3835 "mov w6, #0\n"
3836 "50:\n"
3837 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
3838 "mov w5, #0\n"
3839 "51:\n"
3840 "ldrb w7, [x3, w5, uxtw]\n"
3841 "strb w7, [x4, w5, uxtw]\n"
3842 "add w5, w5, #1\n"
3843 "cmp w5, w1\n"
3844 "blt 51b\n"
3845 "add w6, w6, #1\n"
3846 "add x3, x3, #8\n"
3847 "add x4, x4, x11\n"
3848 "cmp w6, w2\n"
3849 "blt 50b\n"
3850 "41:\n"
3851 "add %[dst_ptr], %[dst_ptr], #8\n"
3852 // At this point we have completely finished writing values to the
3853 // destination matrix for the current block.
3854
3855 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
3856
3857 RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
3858
3859 // Cast-and-saturate from int32 to int16
3860 "sqxtn v16.4h, v16.4s\n"
3861 "sqxtn2 v16.8h, v17.4s\n"
3862 "sqxtn v17.4h, v18.4s\n"
3863 "sqxtn2 v17.8h, v19.4s\n"
3864 "sqxtn v18.4h, v20.4s\n"
3865 "sqxtn2 v18.8h, v21.4s\n"
3866 "sqxtn v19.4h, v22.4s\n"
3867 "sqxtn2 v19.8h, v23.4s\n"
3868 "sqxtn v20.4h, v24.4s\n"
3869 "sqxtn2 v20.8h, v25.4s\n"
3870 "sqxtn v21.4h, v26.4s\n"
3871 "sqxtn2 v21.8h, v27.4s\n"
3872 "sqxtn v22.4h, v28.4s\n"
3873 "sqxtn2 v22.8h, v29.4s\n"
3874 "sqxtn v23.4h, v30.4s\n"
3875 "sqxtn2 v23.8h, v31.4s\n"
3876
3877 // At this point, v24 -- v31 aren't used anymore for the current block,
3878 // so we can start clearing these accumulators for the next block
3879 // (next iteration of the main loop).
3880 RUY_MAKE_ZERO(v24)
3881 RUY_MAKE_ZERO(v25)
3882 RUY_MAKE_ZERO(v26)
3883 RUY_MAKE_ZERO(v27)
3884 RUY_MAKE_ZERO(v28)
3885 RUY_MAKE_ZERO(v29)
3886 RUY_MAKE_ZERO(v30)
3887 RUY_MAKE_ZERO(v31)
3888
3889 // Add the destination zero point
3890 "dup v14.8h, v13.h[4]\n"
3891 "add v16.8h, v16.8h, v14.8h\n"
3892 "add v17.8h, v17.8h, v14.8h\n"
3893 "add v18.8h, v18.8h, v14.8h\n"
3894 "add v19.8h, v19.8h, v14.8h\n"
3895 "add v20.8h, v20.8h, v14.8h\n"
3896 "add v21.8h, v21.8h, v14.8h\n"
3897 "add v22.8h, v22.8h, v14.8h\n"
3898 "add v23.8h, v23.8h, v14.8h\n"
3899
3900         // Cast-and-saturate from int16 to int8
3901 "sqxtn v16.8b, v16.8h\n"
3902 "sqxtn2 v16.16b, v17.8h\n"
3903 "sqxtn v17.8b, v18.8h\n"
3904 "sqxtn2 v17.16b, v19.8h\n"
3905 "sqxtn v18.8b, v20.8h\n"
3906 "sqxtn2 v18.16b, v21.8h\n"
3907 "sqxtn v19.8b, v22.8h\n"
3908 "sqxtn2 v19.16b, v23.8h\n"
3909
3910 // Load the clamp_min, clamp_max bounds
3911 "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
3912 "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
3913 "dup v14.16b, w2\n" // clamp_min
3914 "dup v15.16b, w3\n" // clamp_max
3915
3916 // Apply the clamp_min bound
3917 "smax v16.16b, v16.16b, v14.16b\n"
3918 "smax v17.16b, v17.16b, v14.16b\n"
3919 "smax v18.16b, v18.16b, v14.16b\n"
3920 "smax v19.16b, v19.16b, v14.16b\n"
3921
3922 // Apply the clamp_max bound
3923 "smin v16.16b, v16.16b, v15.16b\n"
3924 "smin v17.16b, v17.16b, v15.16b\n"
3925 "smin v18.16b, v18.16b, v15.16b\n"
3926 "smin v19.16b, v19.16b, v15.16b\n"
3927
3928 // Make it so that all of the final 8bit values are stored in the
3929 // first 64bits of 128bit NEON registers, so they can be stored
3930 // by 64bit st1 store instructions with byte alignment.
3931 "dup d20, v16.d[1]\n"
3932 "dup d21, v17.d[1]\n"
3933 "dup d22, v18.d[1]\n"
3934 "dup d23, v19.d[1]\n"
3935
3936         // Compute how much of the 8x8 block of destination 8bit values that
3937         // we have computed fits in the destination matrix. Typically, all of
3938 // it fits, but when the destination matrix shape is not a multiple
3939 // of 8x8, there are some 8x8 blocks along the boundaries that do
3940 // not fit entirely.
3941 "sub w1, %w[dst_rows], %w[row]\n"
3942 "sub w2, %w[dst_cols], %w[col]\n"
3943 "mov w3, #8\n"
3944 "cmp w1, #8\n"
3945 // Compute w1 = how many rows of the 8x8 block fit
3946 "csel w1, w1, w3, le\n"
3947 "cmp w2, #8\n"
3948 // Compute w2 = how many cols of the 8x8 block fit
3949 "csel w2, w2, w3, le\n"
3950
3951 // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
3952 "cmp w1, w3\n"
3953 "ccmp w2, w3, 0, eq\n"
3954 // Yes, all of the 8x8 block fits, go to fast path.
3955 "beq 130f\n"
3956 // Not all of the 8x8 block fits.
3957 // Set (x3 address, x4 stride) to write to dst_tmp_buf
3958 "mov x3, %[dst_tmp_buf]\n"
3959 "mov x4, #8\n"
3960 "b 131f\n"
3961 "130:\n"
3962 // Yes, all of the 8x8 block fits.
3963 // Set (x3 address, x4 stride) to write directly to destination matrix.
3964 "mov x3, %[dst_ptr]\n"
3965 "mov x4, x11\n"
3966 "131:\n"
3967
3968 // Write our 8bit values to the destination described by
3969 // (x3 address, x4 stride).
3970 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3971 "st1 {v16.8b}, [x3], x4\n"
3972 RUY_MAKE_ZERO(v16)
3973 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3974 "st1 {v20.8b}, [x3], x4\n"
3975 RUY_MAKE_ZERO(v20)
3976 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3977 "st1 {v17.8b}, [x3], x4\n"
3978 RUY_MAKE_ZERO(v17)
3979 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3980 "st1 {v21.8b}, [x3], x4\n"
3981 RUY_MAKE_ZERO(v21)
3982 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3983 "st1 {v18.8b}, [x3], x4\n"
3984 RUY_MAKE_ZERO(v18)
3985 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3986 "st1 {v22.8b}, [x3], x4\n"
3987 RUY_MAKE_ZERO(v22)
3988 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3989 "st1 {v19.8b}, [x3], x4\n"
3990 RUY_MAKE_ZERO(v19)
3991 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
3992 "st1 {v23.8b}, [x3], x4\n"
3993 RUY_MAKE_ZERO(v23)
3994
3995 // For the next block: perform the first few multiply-adds on the data
3996 // that we have already loaded.
3997 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
3998 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
3999 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
4000 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
4001
4002 // If all of the 8x8 block fits, we just finished writing it to the
4003 // destination, so we skip the next part.
4004 "beq 141f\n"
4005 // Not all of the 8x8 block fits in the destination matrix. We just
4006 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
4007 // it to copy into the destination matrix the part that fits.
4008 "mov x3, %[dst_tmp_buf]\n"
4009 "mov x4, %[dst_ptr]\n"
4010 "mov w6, #0\n"
4011 "150:\n"
4012 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4013 "mov w5, #0\n"
4014 "151:\n"
4015 "ldrb w7, [x3, w5, uxtw]\n"
4016 "strb w7, [x4, w5, uxtw]\n"
4017 "add w5, w5, #1\n"
4018 "cmp w5, w1\n"
4019 "blt 151b\n"
4020 "add w6, w6, #1\n"
4021 "add x3, x3, #8\n"
4022 "add x4, x4, x11\n"
4023 "cmp w6, w2\n"
4024 "blt 150b\n"
4025 "141:\n"
4026 "add %[dst_ptr], %[dst_ptr], #8\n"
4027 // At this point we have completely finished writing values to the
4028 // destination matrix for the current block.
4029
4030 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
4031
4032 RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
4033
4034 // Add the destination zero point
4035 "dup v14.8h, v13.h[4]\n"
4036 "saddw v16.4s, v16.4s, v14.4h\n"
4037 "saddw v17.4s, v17.4s, v14.4h\n"
4038 "saddw v18.4s, v18.4s, v14.4h\n"
4039 "saddw v19.4s, v19.4s, v14.4h\n"
4040 "saddw v20.4s, v20.4s, v14.4h\n"
4041 "saddw v21.4s, v21.4s, v14.4h\n"
4042 "saddw v22.4s, v22.4s, v14.4h\n"
4043 "saddw v23.4s, v23.4s, v14.4h\n"
4044 "saddw v24.4s, v24.4s, v14.4h\n"
4045 "saddw v25.4s, v25.4s, v14.4h\n"
4046 "saddw v26.4s, v26.4s, v14.4h\n"
4047 "saddw v27.4s, v27.4s, v14.4h\n"
4048 "saddw v28.4s, v28.4s, v14.4h\n"
4049 "saddw v29.4s, v29.4s, v14.4h\n"
4050 "saddw v30.4s, v30.4s, v14.4h\n"
4051 "saddw v31.4s, v31.4s, v14.4h\n"
4052
4053 // Cast-and-saturate from int32 to int16
4054 "sqxtn v16.4h, v16.4s\n"
4055 "sqxtn2 v16.8h, v17.4s\n"
4056 "sqxtn v17.4h, v18.4s\n"
4057 "sqxtn2 v17.8h, v19.4s\n"
4058 "sqxtn v18.4h, v20.4s\n"
4059 "sqxtn2 v18.8h, v21.4s\n"
4060 "sqxtn v19.4h, v22.4s\n"
4061 "sqxtn2 v19.8h, v23.4s\n"
4062 "sqxtn v20.4h, v24.4s\n"
4063 "sqxtn2 v20.8h, v25.4s\n"
4064 "sqxtn v21.4h, v26.4s\n"
4065 "sqxtn2 v21.8h, v27.4s\n"
4066 "sqxtn v22.4h, v28.4s\n"
4067 "sqxtn2 v22.8h, v29.4s\n"
4068 "sqxtn v23.4h, v30.4s\n"
4069 "sqxtn2 v23.8h, v31.4s\n"
4070
4071 // At this point, v24 -- v31 aren't used anymore for the current block,
4072 // so we can start clearing these accumulators for the next block
4073 // (next iteration of the main loop).
4074 RUY_MAKE_ZERO(v24)
4075 RUY_MAKE_ZERO(v25)
4076 RUY_MAKE_ZERO(v26)
4077 RUY_MAKE_ZERO(v27)
4078 RUY_MAKE_ZERO(v28)
4079 RUY_MAKE_ZERO(v29)
4080 RUY_MAKE_ZERO(v30)
4081 RUY_MAKE_ZERO(v31)
4082
4083 // Load the clamp_min, clamp_max bounds
4084 "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
4085 "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
4086 "dup v14.8h, w2\n" // clamp_min
4087 "dup v15.8h, w3\n" // clamp_max
4088
4089 // Apply the clamp_min bound
4090 "smax v16.8h, v16.8h, v14.8h\n"
4091 "smax v17.8h, v17.8h, v14.8h\n"
4092 "smax v18.8h, v18.8h, v14.8h\n"
4093 "smax v19.8h, v19.8h, v14.8h\n"
4094 "smax v20.8h, v20.8h, v14.8h\n"
4095 "smax v21.8h, v21.8h, v14.8h\n"
4096 "smax v22.8h, v22.8h, v14.8h\n"
4097 "smax v23.8h, v23.8h, v14.8h\n"
4098 // Apply the clamp_max bound
4099 "smin v16.8h, v16.8h, v15.8h\n"
4100 "smin v17.8h, v17.8h, v15.8h\n"
4101 "smin v18.8h, v18.8h, v15.8h\n"
4102 "smin v19.8h, v19.8h, v15.8h\n"
4103 "smin v20.8h, v20.8h, v15.8h\n"
4104 "smin v21.8h, v21.8h, v15.8h\n"
4105 "smin v22.8h, v22.8h, v15.8h\n"
4106 "smin v23.8h, v23.8h, v15.8h\n"
4107
4108         // Compute how much of the 8x8 block of destination 16bit values that
4109         // we have computed fits in the destination matrix. Typically, all of
4110 // it fits, but when the destination matrix shape is not a multiple
4111 // of 8x8, there are some 8x8 blocks along the boundaries that do
4112 // not fit entirely.
4113 "sub w1, %w[dst_rows], %w[row]\n"
4114 "sub w2, %w[dst_cols], %w[col]\n"
4115 "mov w3, #8\n"
4116 "cmp w1, #8\n"
4117 // Compute w1 = how many rows of the 8x8 block fit
4118 "csel w1, w1, w3, le\n"
4119 "cmp w2, #8\n"
4120         // Compute w2 = how many cols of the 8x8 block fit
4121 "csel w2, w2, w3, le\n"
4122
4123 // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
4124 "cmp w1, w3\n"
4125 "ccmp w2, w3, 0, eq\n"
4126 // Yes, all of the 8x8 block fits, go to fast path.
4127 "beq 230f\n"
4128 // Not all of the 8x8 block fits.
4129 // Set (x3 address, x4 stride) to write to dst_tmp_buf
4130 "mov x3, %[dst_tmp_buf]\n"
4131 "mov x4, #16\n"
4132 "b 231f\n"
4133 "230:\n"
4134 // Yes, all of the 8x8 block fits.
4135 // Set (x3 address, x4 stride) to write directly to destination matrix.
4136 "mov x3, %[dst_ptr]\n"
4137 "mov x4, x11\n"
4138 "231:\n"
4139
4140 // Write our 16bit values to the destination described by
4141 // (x3 address, x4 stride).
4142 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4143 "st1 {v16.8h}, [x3], x4\n"
4144 RUY_MAKE_ZERO(v16)
4145 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4146 "st1 {v17.8h}, [x3], x4\n"
4147 RUY_MAKE_ZERO(v17)
4148 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4149 "st1 {v18.8h}, [x3], x4\n"
4150 RUY_MAKE_ZERO(v18)
4151 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4152 "st1 {v19.8h}, [x3], x4\n"
4153 RUY_MAKE_ZERO(v19)
4154 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4155 "st1 {v20.8h}, [x3], x4\n"
4156 RUY_MAKE_ZERO(v20)
4157 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4158 "st1 {v21.8h}, [x3], x4\n"
4159 RUY_MAKE_ZERO(v21)
4160 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4161 "st1 {v22.8h}, [x3], x4\n"
4162 RUY_MAKE_ZERO(v22)
4163 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
4164 "st1 {v23.8h}, [x3], x4\n"
4165 RUY_MAKE_ZERO(v23)
4166
4167 // For the next block: perform the first few multiply-adds on the data
4168 // that we have already loaded.
4169 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
4170 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
4171 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
4172 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
4173
4174 // If all of the 8x8 block fits, we just finished writing it to the
4175 // destination, so we skip the next part.
4176 "beq 241f\n"
4177 // Not all of the 8x8 block fits in the destination matrix. We just
4178 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
4179 // it to copy into the destination matrix the part that fits.
4180 "mov x3, %[dst_tmp_buf]\n"
4181 "mov x4, %[dst_ptr]\n"
4182 "mov w6, #0\n"
4183 "250:\n"
4184 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4185 "mov w5, #0\n"
4186 "251:\n"
4187 "ldrsh w7, [x3, x5, lsl #1]\n"
4188 "strh w7, [x4, x5, lsl #1]\n"
4189 "add w5, w5, #1\n"
4190 "cmp w5, w1\n"
4191 "blt 251b\n"
4192 "add w6, w6, #1\n"
4193 "add x3, x3, #16\n"
4194 "add x4, x4, x11\n"
4195 "cmp w6, w2\n"
4196 "blt 250b\n"
4197 "241:\n"
4198 "add %[dst_ptr], %[dst_ptr], #16\n"
4199 // At this point we have completely finished writing values to the
4200 // destination matrix for the current block.
4201
4202 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
4203
4204 RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
4205
4206 // Since the store type is the same as the accum type, no need for
4207 // downcast. There's also no need for clamp by min/max.
4208
4209         // Compute how much of the 8x8 block of destination 32bit values that
4210         // we have computed fits in the destination matrix. Typically, all of
4211 // it fits, but when the destination matrix shape is not a multiple
4212 // of 8x8, there are some 8x8 blocks along the boundaries that do
4213 // not fit entirely.
4214 "sub w1, %w[dst_rows], %w[row]\n"
4215 "sub w2, %w[dst_cols], %w[col]\n"
4216 "mov w3, #8\n"
4217 "cmp w1, #8\n"
4218 // Compute w1 = how many rows of the 8x8 block fit
4219 "csel w1, w1, w3, le\n"
4220 "cmp w2, #8\n"
4221         // Compute w2 = how many cols of the 8x8 block fit
4222 "csel w2, w2, w3, le\n"
4223
4224 // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
4225 "cmp w1, w3\n"
4226 "ccmp w2, w3, 0, eq\n"
4227 // Yes, all of the 8x8 block fits, go to fast path.
4228 "beq 330f\n"
4229 // Not all of the 8x8 block fits.
4230 // Write to dst_tmp_buf
4231 "mov x3, %[dst_tmp_buf]\n"
4232 "st1 {v16.4s}, [x3], #16\n"
4233 RUY_MAKE_ZERO(v16)
4234 "st1 {v17.4s}, [x3], #16\n"
4235 RUY_MAKE_ZERO(v17)
4236 "st1 {v18.4s}, [x3], #16\n"
4237 RUY_MAKE_ZERO(v18)
4238 "st1 {v19.4s}, [x3], #16\n"
4239 RUY_MAKE_ZERO(v19)
4240 "st1 {v20.4s}, [x3], #16\n"
4241 RUY_MAKE_ZERO(v20)
4242 "st1 {v21.4s}, [x3], #16\n"
4243 RUY_MAKE_ZERO(v21)
4244 "st1 {v22.4s}, [x3], #16\n"
4245 RUY_MAKE_ZERO(v22)
4246 "st1 {v23.4s}, [x3], #16\n"
4247 RUY_MAKE_ZERO(v23)
4248 "st1 {v24.4s}, [x3], #16\n"
4249 RUY_MAKE_ZERO(v24)
4250 "st1 {v25.4s}, [x3], #16\n"
4251 RUY_MAKE_ZERO(v25)
4252 "st1 {v26.4s}, [x3], #16\n"
4253 RUY_MAKE_ZERO(v26)
4254 "st1 {v27.4s}, [x3], #16\n"
4255 RUY_MAKE_ZERO(v27)
4256 "st1 {v28.4s}, [x3], #16\n"
4257 RUY_MAKE_ZERO(v28)
4258 "st1 {v29.4s}, [x3], #16\n"
4259 RUY_MAKE_ZERO(v29)
4260 "st1 {v30.4s}, [x3], #16\n"
4261 RUY_MAKE_ZERO(v30)
4262 "st1 {v31.4s}, [x3], #16\n"
4263 RUY_MAKE_ZERO(v31)
4264
4265 "b 331f\n"
4266
4267 "330:\n"
4268 // Yes, all of the 8x8 block fits.
4269 "mov x4, %[dst_ptr]\n"
4270 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4271 "mov x3, x4\n"
4272 "st1 {v16.4s, v17.4s}, [x3], #32\n"
4273 RUY_MAKE_ZERO(v16)
4274 RUY_MAKE_ZERO(v17)
4275 "add x4, x4, x11\n"
4276 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4277 "mov x3, x4\n"
4278 "st1 {v18.4s, v19.4s}, [x3], #32\n"
4279 RUY_MAKE_ZERO(v18)
4280 RUY_MAKE_ZERO(v19)
4281 "add x4, x4, x11\n"
4282 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4283 "mov x3, x4\n"
4284 "st1 {v20.4s, v21.4s}, [x3], #32\n"
4285 RUY_MAKE_ZERO(v20)
4286 RUY_MAKE_ZERO(v21)
4287 "add x4, x4, x11\n"
4288 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4289 "mov x3, x4\n"
4290 "st1 {v22.4s, v23.4s}, [x3], #32\n"
4291 RUY_MAKE_ZERO(v22)
4292 RUY_MAKE_ZERO(v23)
4293 "add x4, x4, x11\n"
4294 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4295 "mov x3, x4\n"
4296 "st1 {v24.4s, v25.4s}, [x3], #32\n"
4297 RUY_MAKE_ZERO(v24)
4298 RUY_MAKE_ZERO(v25)
4299 "add x4, x4, x11\n"
4300 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4301 "mov x3, x4\n"
4302 "st1 {v26.4s, v27.4s}, [x3], #32\n"
4303 RUY_MAKE_ZERO(v26)
4304 RUY_MAKE_ZERO(v27)
4305 "add x4, x4, x11\n"
4306 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4307 "mov x3, x4\n"
4308 "st1 {v28.4s, v29.4s}, [x3], #32\n"
4309 RUY_MAKE_ZERO(v28)
4310 RUY_MAKE_ZERO(v29)
4311 "add x4, x4, x11\n"
4312 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4313 "mov x3, x4\n"
4314 "st1 {v30.4s, v31.4s}, [x3], #32\n"
4315 RUY_MAKE_ZERO(v30)
4316 RUY_MAKE_ZERO(v31)
4317
4318 "331:\n"
4319
4320 // For the next block: perform the first few multiply-adds on the data
4321 // that we have already loaded.
4322 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
4323 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
4324 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
4325 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
4326
4327 // If all of the 8x8 block fits, we just finished writing it to the
4328 // destination, so we skip the next part.
4329 "beq 341f\n"
4330
4331 // Not all of the 8x8 block fits in the destination matrix. We just
4332 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
4333 // it to copy into the destination matrix the part that fits.
4334 "mov x3, %[dst_tmp_buf]\n"
4335 "mov x4, %[dst_ptr]\n"
4336 "mov w6, #0\n"
4337 "350:\n"
4338 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
4339 "mov w5, #0\n"
4340 "351:\n"
4341 "ldr w7, [x3, x5, lsl #2]\n"
4342 "str w7, [x4, x5, lsl #2]\n"
4343 "add w5, w5, #1\n"
4344 "cmp w5, w1\n"
4345 "blt 351b\n"
4346 "add w6, w6, #1\n"
4347 "add x3, x3, #32\n"
4348 "add x4, x4, x11\n"
4349 "cmp w6, w2\n"
4350 "blt 350b\n"
4351 "341:\n"
4352 "add %[dst_ptr], %[dst_ptr], #32\n"
4353 // At this point we have completely finished writing values to the
4354 // destination matrix for the current block.
4355
4356 RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
4357
4358 // Reload some params --- we had used x5 -- x7 for a few other things
4359 // since the last time we had loaded them.
4360 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
4361 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
4362 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
4363
4364 // Move to the next block of the destination matrix, for the next iter
4365 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
4366 // been updated earlier.
4367 // Have we reached the end row?
4368 "cmp %w[row], w7\n"
4369 "beq 20f\n" // yes, end row.
4370 // Not end row. Move to the next row.
4371 "add %w[row], %w[row], #8\n"
4372 "b 21f\n"
4373 "20:\n"
4374 // Was already at end row.
4375 "mov %w[row], w6\n" // Move back to first row.
4376 "add %w[col], %w[col], #8\n" // Move to the next column.
4377 "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
4378 "mov %[dst_ptr], %[dst_col_ptr]\n"
4379 "21:\n"
4380
4381 // Main loop exit condition: have we hit the end column?
4382 "cmp %w[col], w8\n"
4383
4384 // w1 is the number of levels of depth that we have already loaded
4385 // LHS and RHS data for. Corresponding to the initial ld1 instructions
4386 // above, this is currently 4.
4387 "mov w1, #4\n"
4388
4389 "ble 1b\n"
4390
4391 // clang-format on
4392
4393 : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
4394 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
4395 [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
4396         : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
4397 [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
4398 [dst_type_id] "r"(params.dst_type_id)
4399 : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
4400 "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
4401 "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
4402 "v26", "v27", "v28", "v29", "v30", "v31");
4403 }
4404
4405 // Similar to the above 8-bit dotprod kernel, but specialized for the case of
4406 // RHS cols == 1.
4407 // Relevant target CPUs for this kernel include ARM Cortex-A76,
4408 // since it is 64-bit, out-of-order, and has dotprod support.
4409 void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params) {
4410 profiler::ScopeLabel label("Kernel (kNeonDotprod)");
4411
4412 CheckOffsetsInKernelParams8bit(params);
4413
4414 const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
4415 const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
4416 const std::int8_t* lhs_ptr = lhs_col_ptr;
4417 const std::int8_t* rhs_ptr = rhs_col_ptr;
4418 void* dst_col_ptr = params.dst_base_ptr;
4419 void* dst_ptr = dst_col_ptr;
4420 int row = params.start_row;
4421 int col = params.start_col;
4422
4423 RUY_DCHECK(!(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL));
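// This kernel computes a single RHS column, so the case of the channel
// dimension being columns is not supported here, hence the check above.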
4424
4425 // The asm kernel below has the following NEON register allocation:
4426 //
4427   // v16 -- v17 are int32 accumulators.
4428   // During accumulation, v0 -- v2 are used to load int8 data from LHS and
4429   // RHS. v0 and v1 are used to load an 8x4 block of LHS, and v2 is used to
4430   // load a 4x1 block of RHS, like this:
4431 //
4432 // int8 RHS 4x1 block
4433 // /-------|
4434 // |v2.b[0]|
4435 // | ... |
4436 // |v2.b[3]|
4437 // \-------/
4438 // int8 LHS 8x4 block
4439 // /---------------------\ /--------|
4440 // |v0.b[0] ... v0.b[3] | |v16.s[0]|
4441 // | ... ... | | ... |
4442 // |v0.b[12] ... v0.b[15]| |v16.s[3]|
4443 // |v1.b[0] ... v1.b[3] | |v17.s[0]|
4444 // | ... ... | | ... |
4445 // |v1.b[12] ... v1.b[15]| |v17.s[3]|
4446 // \---------------------/ \--------/
4447 // int32 accumulators 8x1 block
4448 //
4449   // Unlike the 8x8 kernel above, this kernel has no RUY_OPT_MAX_STREAMING
4450   // variant: v0 -- v2 are the only registers used to load LHS and RHS
4451   // data, and v4 -- v7 are unused.
4452   //
4453   // As in the kernel above, v8 -- v15 are used for loading parameters
4454   // (bias, sums, multipliers, zero points) for the post-accumulation part
4455   // of the kernel.
4456 asm volatile(
4457 #define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
4458
4459 // clang-format off
4460
4461 // Load some parameters into registers.
4462 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
4463 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
4464 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
4465 "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
4466 "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
4467 "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
4468 "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
4469 "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
4470
4471         // Load the first 32 bytes of LHS data and the first 8 bytes of RHS data.
4472 "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
4473 "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
4474 "ld1 {v2.8b}, [%[rhs_ptr]]\n"
4475 "add %[rhs_ptr], %[rhs_ptr], #32\n"
4476
4477 // Clear accumulators.
4478 RUY_MAKE_ZERO(v16)
4479 RUY_MAKE_ZERO(v17)
4480
4481 // w1 is the number of levels of depth that we have already loaded
4482 // LHS and RHS data for. Corresponding to the initial ld1 instructions
4483 // above, this is currently 4.
4484 "mov w1, #4\n"
4485
4486 // Perform the first few multiply-adds on the data that we have already
4487 // loaded.
4488 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
4489
4490 // Main loop of the whole GEMM, over rows and columns of the
4491 // destination matrix.
4492 "1:\n"
4493
4494         // Kernel inner loop (over depth). Unlike the kernel above, this kernel
4495         // has no 4x-partially-unrolled RUY_OPT_MAX_STREAMING variant.
4496
4497 // Reminder - w1 is how many levels of depth we have already loaded
4498 // data for, w12 is the total depth.
4499 "cmp w1, w12\n"
4500 "beq 79f\n"
4501
4502 "2:\n"
4503
4504 // Because of the data that we have already loaded, we can start the
4505 // loop body right away with some multiply-adds.
4506 // Each iteration of this loop advances by 4 levels of depth.
4507 "add w1, w1, #4\n"
4508 "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
4509 ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
4510 // Loop termination condition.
4511 "cmp w1, w12\n"
4512 "ld1 {v2.8b}, [%[rhs_ptr]]\n"
4513 "add %[rhs_ptr], %[rhs_ptr], #32\n"
4514 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
4515 "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
4516
4517 "blt 2b\n"
4518
4519 "79:\n"
4520 // End of the inner loop on depth. Now perform the remaining
4521 // multiply-adds of the last 4 levels of depth, for which the LHS
4522 // and RHS data is already loaded.
4523
4524 ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
4525
4526         // End of accumulation. The registers v16 -- v17 contain the final
4527         // int32 accumulator values of the current 8x1 destination block.
4528         // We now have to compute the final 8-bit values from these int32
4529         // accumulators, and advance to the next 8x1 block. We intertwine
4530 // these two aspects whenever possible for optimal pipelining, both
4531 // at the data flow level (prefetch data for next block as early as
4532 // possible) and instruction pipelining level (some of the next-block
4533 // work can dual-issue with some of the final work on the current
4534 // block).
4535
4536 // Logic to advance to the next block in preparation for the next
4537 // iteration of the main loop. For now, we only want to compute
4538 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
4539 // not yet ready to update the values of row and col, as we still need
4540 // the current values for the rest of the work on the current block.
4541
4542 "cmp %w[row], w7\n" // Have we finished the last row?
4543 "bge 4f\n" // If finished last row, go to 4
4544 // Not finished last row: then advance to next row.
4545 "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
4546 "b 5f\n"
4547 "4:\n" // Finished last row...
4548 "mov %[lhs_col_ptr], x5\n" // Go back to first row
4549 // Now we need to advance to the next column. If we already
4550 // finished the last column, then in principle we are done, however
4551 // we can't just return here, as we need to allow the end work of the
4552 // current block to complete. The good news is that at this point it
4553 // doesn't matter what data we load for the next column, since
4554 // we will exit from the main loop below before actually storing
4555 // anything computed from that data.
4556 "cmp %w[col], w8\n" // Have we finished the last column?
4557 "bge 5f\n" // If yes, just carry on without updating the column pointer.
4558 // Not finished last column: then advance to next column.
4559 "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
4560 "5:\n"
4561
4562 // Set the LHS and RHS data pointers to the start of the columns just
4563 // computed.
4564 "mov %[lhs_ptr], %[lhs_col_ptr]\n"
4565 "mov %[rhs_ptr], %[rhs_col_ptr]\n"
4566
4567 // Load some parameters needed for the end work on current block.
4568 "mvni v8.4s, #0\n"
4569 "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
4570 "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
4571 "ins v13.h[4], w4\n" // dst_zero_point
4572 "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
4573 "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
4574 "dup v9.4s, w3\n" // create prod_zp_depth_vec
4575 "add x5, x4, %x[row], lsl #2\n"
4576 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
4577 "csel x4, x4, x5, eq\n"
4578
4579 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
4580 "add x5, x1, %x[row], lsl #2\n"
4581
4582 "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
4583 "csel x1, x1, x5, eq\n"
4584
4585 // Load 8 bias values.
4586 "ld1 {v14.4s}, [x1], #16\n"
4587 "ld1 {v15.4s}, [x1]\n"
4588
4589         // Now that we know what LHS and RHS data the next iteration of the
4590         // main loop will need to load, we start loading the first 32 bytes of
4591         // LHS into v0 -- v1 and the first 8 bytes of RHS into v2, as we don't
4592         // need these registers anymore for the rest of the work on the current block.
4593 "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
4594 "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
4595 "ld1 {v2.8b}, [%[rhs_ptr]]\n"
4596 "add %[rhs_ptr], %[rhs_ptr], #32\n"
4597
4598 // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
4599 // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
4600 "add v14.4s, v14.4s, v9.4s\n"
4601 "add v15.4s, v15.4s, v9.4s\n"
4602
4603 // Perform the bias-addition (per the above, we have just folded into
4604 // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
4605 "add v16.4s, v16.4s, v14.4s\n"
4606 "add v17.4s, v17.4s, v15.4s\n"

"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
"beq 401f\n"
"ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
"add x3, x3, %x[col], lsl #2\n"
"ld1 {v14.4s}, [x3], #16\n"
"ld1 {v15.4s}, [x3]\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
"dup v10.4s, w5\n"  // create lhs_zero_point_vec
// Subtract rhs_sums * lhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"mls v16.4s, v10.4s, v14.s[0]\n"
"mls v17.4s, v10.4s, v14.s[0]\n"
"401:\n"

"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
"beq 402f\n"
"ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
"add x2, x2, %x[row], lsl #2\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
// Load 8 lhs_sums values.
"ld1 {v11.4s}, [x2], #16\n"
"ld1 {v12.4s}, [x2]\n"
"ins v13.s[1], w5\n" // rhs_zero_point
// Compute lhs_sums * rhs_zero_point.
"mul v11.4s, v11.4s, v13.s[1]\n"
"mul v12.4s, v12.4s, v13.s[1]\n"
// Subtract lhs_sums * rhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"sub v16.4s, v16.4s, v11.4s\n"
"sub v17.4s, v17.4s, v12.4s\n"

"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"

"402:\n"

// At this point we have computed the final int32 values. Now we
// start down-quantizing them to obtain the final 8bit values from them.

// As part of this down-quantization, our int32 values will be
// multiplied by a multiplier that has a fixed-point component and an
// exponent component.

// Load the exponent part of the multiplier.
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"add x5, x1, %x[row], lsl #2\n"
"csel x1, x1, x5, eq\n"

"ldr q9, [x1]\n"
"ldr q10, [x1, #16]\n"

// Separate positive and negative exponents.
"smin v11.4s, v8.4s, v9.4s\n"
"smin v12.4s, v8.4s, v10.4s\n"
"sub v9.4s, v9.4s, v11.4s\n"
"sub v10.4s, v10.4s, v12.4s\n"

// Apply the positive exponent part of the multiplier.
"sshl v16.4s, v16.4s, v9.4s\n"
"sshl v17.4s, v17.4s, v10.4s\n"
"403:\n"

"ldr q14, [x4]\n" // multiplier_fixedpoint
"ldr q15, [x4, #16]\n" // multiplier_fixedpoint

// Apply the fixed-point part of the multiplier.
"sqdmulh v16.4s, v16.4s, v14.4s\n"
"sqdmulh v17.4s, v17.4s, v15.4s\n"

// Apply the negative exponent part of the multiplier.
"srshl v16.4s, v16.4s, v11.4s\n"
"srshl v17.4s, v17.4s, v12.4s\n"

"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"

RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"

// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
// All data in v16 at this point.

// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"add v16.8h, v16.8h, v14.8h\n"

// Cast-and-saturate from int16 to uint8, leaving all data in the
// lower half of v16.
"sqxtun v16.8b, v16.8h\n"

// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.16b, w2\n"  // clamp_min
"dup v15.16b, w3\n"  // clamp_max

// Apply the clamp_min bound
"umax v16.16b, v16.16b, v14.16b\n"

// Apply the clamp_max bound
"umin v16.16b, v16.16b, v15.16b\n"

// Make it so that all of the final 8bit values are stored in the
// first 64bits of 128bit NEON registers, so they can be stored
// by 64bit st1 store instructions with byte alignment.
"dup d20, v16.d[1]\n"

// Compute how much of the 8x1 block of destination 8bit values we
// have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x1, there are some 8x1 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x1 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"

// Test if w1==8, i.e. if all of the 8x1 block fits.
"cmp w1, w3\n"
// Yes, all of the 8x1 block fits, go to fast path.
"beq 30f\n"
// Not all of the 8x1 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #8\n"
"b 31f\n"
"30:\n"
// Yes, all of the 8x1 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"31:\n"

// Write our 8bit values to the destination
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v16.8b}, [x3]\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)

// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010  // sdot v16.4s, v0.16b, v2.4b[0]\n"

// If all of the 8x1 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 41f\n"
// Not all of the 8x1 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"41:\n"
"add %[dst_ptr], %[dst_ptr], #8\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.

"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"

RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"

// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"

// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"add v16.8h, v16.8h, v14.8h\n"

// Cast-and-saturate from int16 to int8
"sqxtn v16.8b, v16.8h\n"

// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.16b, w2\n"  // clamp_min
"dup v15.16b, w3\n"  // clamp_max

// Apply the clamp_min bound
"smax v16.16b, v16.16b, v14.16b\n"

// Apply the clamp_max bound
"smin v16.16b, v16.16b, v15.16b\n"

// Make it so that all of the final 8bit values are stored in the
// first 64bits of 128bit NEON registers, so they can be stored
// by 64bit st1 store instructions with byte alignment.
"dup d20, v16.d[1]\n"

// Compute how much of the 8x1 block of destination 8bit values we
// have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x1, there are some 8x1 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x1 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"

// Test if w1==8, i.e. if all of the 8x1 block fits.
"cmp w1, w3\n"
// Yes, all of the 8x1 block fits, go to fast path.
"beq 130f\n"
// Not all of the 8x1 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #8\n"
"b 131f\n"
"130:\n"
// Yes, all of the 8x1 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"131:\n"

// Write our 8bit values to the destination
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v16.8b}, [x3]\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)

// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010  // sdot v16.4s, v0.16b, v2.4b[0]\n"

// If all of the 8x1 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 141f\n"
// Not all of the 8x1 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"150:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"151:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 151b\n"
"141:\n"
"add %[dst_ptr], %[dst_ptr], #8\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.

"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"

RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"

// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"saddw v16.4s, v16.4s, v14.4h\n"
"saddw v17.4s, v17.4s, v14.4h\n"

// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"

// Load the clamp_min, clamp_max bounds
"ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.8h, w2\n"  // clamp_min
"dup v15.8h, w3\n"  // clamp_max

// Apply the clamp_min bound
"smax v16.8h, v16.8h, v14.8h\n"
// Apply the clamp_max bound
"smin v16.8h, v16.8h, v15.8h\n"

// Compute how much of the 8x1 block of destination 16bit values we
// have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x1, there are some 8x1 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x1 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"

// Test if w1==8, i.e. if all of the 8x1 block fits.
"cmp w1, w3\n"
// Yes, all of the 8x1 block fits, go to fast path.
"beq 230f\n"
// Not all of the 8x1 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #16\n"
"b 231f\n"
"230:\n"
// Yes, all of the 8x1 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"231:\n"

// Write our 16bit values to the destination
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v16.8h}, [x3]\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)

// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010  // sdot v16.4s, v0.16b, v2.4b[0]\n"

// If all of the 8x1 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 241f\n"
// Not all of the 8x1 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"250:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"251:\n"
"ldrsh w7, [x3, x5, lsl #1]\n"
"strh w7, [x4, x5, lsl #1]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 251b\n"
"241:\n"
"add %[dst_ptr], %[dst_ptr], #16\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.

"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"

RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"

// Since the store type is the same as the accum type, no need for
// downcast. There's also no need for clamp by min/max.

// Compute how much of the 8x1 block of destination 32bit values we
// have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x1, there are some 8x1 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x1 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x1 block fit
"csel w2, w2, w3, le\n"

// Test if w1==8, i.e. if all of the 8x1 block fits.
"cmp w1, w3\n"
// Yes, all of the 8x1 block fits, go to fast path.
"beq 330f\n"
// Not all of the 8x1 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #16\n"

// Write our 32bit values to the destination described by
// (x3 address, x4 stride).
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v16.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v16)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v17.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v17)

"b 331f\n"

"330:\n"
// Yes, all of the 8x1 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x4, %[dst_ptr]\n"
"mov x3, x4\n"

// Write our 32bit values to the destination described by
// (x3 address, x4 stride).
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v16.4s, v17.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)

"331:\n"

// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010  // sdot v16.4s, v0.16b, v2.4b[0]\n"

// If all of the 8x1 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 341f\n"

// Not all of the 8x1 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"350:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"351:\n"
"ldr w7, [x3, x5, lsl #2]\n"
"str w7, [x4, x5, lsl #2]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 351b\n"
"341:\n"
"add %[dst_ptr], %[dst_ptr], #32\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.

RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"

// Reload some params --- we had used x5 -- x7 for a few other things
// since the last time we had loaded them.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"

// Move to the next block of the destination matrix, for the next iter
// of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
// been updated earlier.
// Have we reached the end row?
"cmp %w[row], w7\n"
"beq 20f\n"  // yes, end row.
// Not end row. Move to the next row.
"add %w[row], %w[row], #8\n"
"b 21f\n"
"20:\n"
// Was already at end row.
"mov %w[row], w6\n"  // Move back to first row.
"add %w[col], %w[col], #8\n"  // Move to the next column.
"add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
"mov %[dst_ptr], %[dst_col_ptr]\n"
"21:\n"

// Main loop exit condition: have we hit the end column?
"cmp %w[col], w8\n"

// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 4.
"mov w1, #4\n"

"ble 1b\n"

// clang-format on

: [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
  [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
  [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
: [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
  [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
  [dst_type_id] "r"(params.dst_type_id)
: "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
  "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
  "v13", "v14", "v15", "v16", "v17");
}

// Variant of the above Kernel8bitNeonDotprod, tuned for in-order
// CPUs. Specifically here, the relevant in-order CPU is the ARM
// Cortex-A55r1, which is 64-bit and supports dotprod.
//
// While this kernel does not have a direct equivalent in gemmlowp, it was
// developed based on insights that David Mansell at ARM shared with their
// contribution of gemmlowp kernels tuned for Cortex-A55r1, with very helpful
// comments. Specifically, see this comment about tuning for Cortex-A55r1:
// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412
void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params) {
  profiler::ScopeLabel label(
      "Kernel (kNeonDotprod, optimized for in-order cores)");

  CheckOffsetsInKernelParams8bit(params);

  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
  const std::int8_t* lhs_ptr = lhs_col_ptr;
  const std::int8_t* rhs_ptr = rhs_col_ptr;
  void* dst_col_ptr = params.dst_base_ptr;
  void* dst_ptr = dst_col_ptr;
  int row = params.start_row;
  int col = params.start_col;

// The asm kernel below has the following NEON register allocation:
//
// v16 -- v31 are int32 accumulators.
// During accumulation, v0 -- v3 are used to load int8 data from LHS and
// RHS.
//
//                                      int8 RHS 4x8 block
//           /-----------------------------------------|
//           |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]|
//           |  ...                              ...   |
//           |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]|
//           \-----------------------------------------/
//    int8 LHS 8x4 block
//  /---------------------\ /-----------------------------------------|
//  |v0.b[0]  ... v0.b[3] | |v16.s[0]           ...           v30.s[0]|
//  |  ...          ...   | |  ...                              ...   |
//  |v0.b[12] ... v0.b[15]| |v16.s[3]           ...           v30.s[3]|
//  |v1.b[0]  ... v1.b[3] | |v17.s[0]           ...           v31.s[0]|
//  |  ...          ...   | |  ...                              ...   |
//  |v1.b[12] ... v1.b[15]| |v17.s[3]           ...           v31.s[3]|
//  \---------------------/ \-----------------------------------------/
//                                  int32 accumulators 8x8 block
//
// There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because
// we did not observe a benefit of such partial unrolling on in-order CPUs.
//
// v4 -- v7 are used as temporaries by the per-channel (columns) paths
// below, and v8 -- v15 are used for loading parameters used for
// the post-accumulation part of the kernel.
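//
// For orientation, the control flow implemented by the assembly below
// is roughly the following pseudocode (illustrative C only; the names
// mirror the comments in the assembly, not actual ruy helpers):
//
//   do {                                   // one 8x8 block per iteration
//     int32x4_t acc[16] = {0};             // v16 -- v31
//     for (int d = 0; d < depth; d += 4)   // main loop, label 2
//       acc += sdot(8x4 int8 LHS block, 4x8 int8 RHS block);
//     // "end work": bias, zero points, multiplier, downcast, clamp,
//     // store (per-dst-type labels), then advance row/col pointers.
//   } while (col <= last_col);             // "ble 1b" at the end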
  asm volatile(
#define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"

      // clang-format off

// Load some parameters into registers.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
RUY_MAKE_ZERO(v16)
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
RUY_MAKE_ZERO(v17)
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
RUY_MAKE_ZERO(v18)
"ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
RUY_MAKE_ZERO(v19)
"ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
RUY_MAKE_ZERO(v20)
"ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
RUY_MAKE_ZERO(v21)
"ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
RUY_MAKE_ZERO(v22)
"ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"

// Load the first 32 bytes of LHS and RHS data.
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v3.16b}, [%[rhs_ptr]], #16\n"

// Clear accumulators.
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
// Perform the first few multiply-adds on the data that we have already
// loaded.
".word 0x4f82e010  // sdot v16.4s, v0.16b, v2.4b[0]\n"
RUY_MAKE_ZERO(v28)
".word 0x4fa2e012  // sdot v18.4s, v0.16b, v2.4b[1]\n"
RUY_MAKE_ZERO(v29)
".word 0x4f82e814  // sdot v20.4s, v0.16b, v2.4b[2]\n"
RUY_MAKE_ZERO(v30)
".word 0x4fa2e816  // sdot v22.4s, v0.16b, v2.4b[3]\n"
RUY_MAKE_ZERO(v31)


"1:\n"

"add x5, %[lhs_ptr], x12, lsl #3\n"  // x5 = end of LHS data for this block
"sub x5, x5, #32\n"                  // minus the 32 bytes already loaded
"cmp %[lhs_ptr], x5\n"

"beq 79f\n"  // Skip the main loop if depth == 4 (all data already loaded).

// Main accumulation loop
"2:\n"
".word 0x4f83e018  // sdot v24.4s, v0.16b, v3.4b[0]\n"
"ldr x1, [%[lhs_ptr], #8]\n"
".word 0x4fa3e01a  // sdot v26.4s, v0.16b, v3.4b[1]\n"
"ldr x3, [%[rhs_ptr], #8]\n"
".word 0x4f83e81c  // sdot v28.4s, v0.16b, v3.4b[2]\n"
"ldr x4, [%[rhs_ptr], #24]\n"
".word 0x4fa3e81e  // sdot v30.4s, v0.16b, v3.4b[3]\n"
"ldr d0, [%[lhs_ptr], #0]\n"
".word 0x4f82e031  // sdot v17.4s, v1.16b, v2.4b[0]\n"
"ins v0.d[1], x1\n"
".word 0x4fa2e033  // sdot v19.4s, v1.16b, v2.4b[1]\n"
"ldr x2, [%[lhs_ptr], #24]\n"
".word 0x4f82e835  // sdot v21.4s, v1.16b, v2.4b[2]\n"
"add %[lhs_ptr], %[lhs_ptr], #32\n"
".word 0x4fa2e837  // sdot v23.4s, v1.16b, v2.4b[3]\n"
"ldr d2, [%[rhs_ptr], #0]\n"
".word 0x4f83e039  // sdot v25.4s, v1.16b, v3.4b[0]\n"
"ins v2.d[1], x3\n"
".word 0x4fa3e03b  // sdot v27.4s, v1.16b, v3.4b[1]\n"
"cmp %[lhs_ptr], x5\n"
".word 0x4f83e83d  // sdot v29.4s, v1.16b, v3.4b[2]\n"
"add %[rhs_ptr], %[rhs_ptr], #32\n"
".word 0x4fa3e83f  // sdot v31.4s, v1.16b, v3.4b[3]\n"
"ldr d3, [%[rhs_ptr], #-16]\n"
".word 0x4f82e010  // sdot v16.4s, v0.16b, v2.4b[0]\n"
"ldr d1, [%[lhs_ptr], #-16]\n"
".word 0x4fa2e012  // sdot v18.4s, v0.16b, v2.4b[1]\n"
"ins v3.d[1], x4\n"
".word 0x4f82e814  // sdot v20.4s, v0.16b, v2.4b[2]\n"
"ins v1.d[1], x2\n"
".word 0x4fa2e816  // sdot v22.4s, v0.16b, v2.4b[3]\n"
"blt 2b\n"

// Last accumulation steps, nothing left to load.
"79:\n"
".word 0x4f83e018  // sdot v24.4s, v0.16b, v3.4b[0]\n"
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
".word 0x4fa3e01a  // sdot v26.4s, v0.16b, v3.4b[1]\n"
"cmp %w[row], w7\n"  // Have we finished the last row?
".word 0x4f83e81c  // sdot v28.4s, v0.16b, v3.4b[2]\n"
".word 0x4fa3e81e  // sdot v30.4s, v0.16b, v3.4b[3]\n"
".word 0x4f82e031  // sdot v17.4s, v1.16b, v2.4b[0]\n"
".word 0x4fa2e033  // sdot v19.4s, v1.16b, v2.4b[1]\n"
".word 0x4f82e835  // sdot v21.4s, v1.16b, v2.4b[2]\n"
".word 0x4fa2e837  // sdot v23.4s, v1.16b, v2.4b[3]\n"
".word 0x4f83e039  // sdot v25.4s, v1.16b, v3.4b[0]\n"
".word 0x4fa3e03b  // sdot v27.4s, v1.16b, v3.4b[1]\n"
".word 0x4f83e83d  // sdot v29.4s, v1.16b, v3.4b[2]\n"
".word 0x4fa3e83f  // sdot v31.4s, v1.16b, v3.4b[3]\n"

// End of accumulation. The registers v16 -- v31 contain the final
// int32 accumulator values of the current 8x8 destination block.
// We now have to compute the final 8-bit values from these int32
// accumulators, and advance to the next 8x8 block. We intertwine
// these two aspects whenever possible for optimal pipelining, both
// at the data flow level (prefetch data for next block as early as
// possible) and instruction pipelining level (some of the next-block
// work can dual-issue with some of the final work on the current
// block).

// Logic to advance to the next block in preparation for the next
// iteration of the main loop. For now, we only want to compute
// the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
// not yet ready to update the values of row and col, as we still need
// the current values for the rest of the work on the current block.

"bge 4f\n"  // If finished last row, go to 4
// Not finished last row: then advance to next row.
"add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
"b 5f\n"
"4:\n"  // Finished last row...
"mov %[lhs_col_ptr], x5\n"  // Go back to first row
// Now we need to advance to the next column. If we already
// finished the last column, then in principle we are done; however,
// we can't just return here, as we need to allow the end work of the
// current block to complete. The good news is that at this point it
// doesn't matter what data we load for the next column, since
// we will exit from the main loop below before actually storing
// anything computed from that data.
"cmp %w[col], w8\n"  // Have we finished the last column?
"bge 5f\n" // If yes, just carry on without updating the column pointer.
// Not finished last column: then advance to next column.
"add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
"5:\n"

// Set the LHS and RHS data pointers to the start of the columns just
// computed.
"mov %[lhs_ptr], %[lhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
"mvni v8.4s, #0\n"  // v8 = all-ones = -1, used below to split the exponent
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
"ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"dup v9.4s, w3\n"   // create prod_zp_depth_vec

"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
// Determine the channel index.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"csel w3, %w[row], %w[col], eq\n"

// Offset the bias pointer as needed given the current row, col.
"add x5, x1, x3, lsl #2\n"

// If there is no bias, use no offset, just address the passed zero
// data.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
"csel x1, x1, x5, eq\n"

// Load 8 bias values.
"ld1 {v14.2s}, [x1], #8\n"
"ldr x5, [x1], #8\n"
"ins v14.d[1], x5\n"
"ld1 {v15.2s}, [x1], #8\n"
"ldr x5, [x1], #8\n"
"ins v15.d[1], x5\n"

// Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point);
// see the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"add v14.4s, v14.4s, v9.4s\n"
"add v15.4s, v15.4s, v9.4s\n"
// Perform the bias-addition (per the above, we have just folded into
// the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
// Jump based on channel dimension.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"bne 6f\n"
// Case where channels are rows
"add v16.4s, v16.4s, v14.4s\n"
"add v17.4s, v17.4s, v15.4s\n"
"add v18.4s, v18.4s, v14.4s\n"
"add v19.4s, v19.4s, v15.4s\n"
"add v20.4s, v20.4s, v14.4s\n"
"add v21.4s, v21.4s, v15.4s\n"
"add v22.4s, v22.4s, v14.4s\n"
"add v23.4s, v23.4s, v15.4s\n"
"add v24.4s, v24.4s, v14.4s\n"
"add v25.4s, v25.4s, v15.4s\n"
"add v26.4s, v26.4s, v14.4s\n"
"add v27.4s, v27.4s, v15.4s\n"
"add v28.4s, v28.4s, v14.4s\n"
"add v29.4s, v29.4s, v15.4s\n"
"add v30.4s, v30.4s, v14.4s\n"
"add v31.4s, v31.4s, v15.4s\n"
"b 7f\n"

"6:\n"
// Case where channels are columns
"dup v10.4s, v14.s[0]\n"
"dup v11.4s, v14.s[1]\n"
"add v16.4s, v16.4s, v10.4s\n"
"dup v12.4s, v14.s[2]\n"
"add v17.4s, v17.4s, v10.4s\n"
"dup v13.4s, v14.s[3]\n"
"add v18.4s, v18.4s, v11.4s\n"
"dup v10.4s, v15.s[0]\n"
"add v19.4s, v19.4s, v11.4s\n"
"dup v11.4s, v15.s[1]\n"
"add v20.4s, v20.4s, v12.4s\n"
"add v21.4s, v21.4s, v12.4s\n"
"dup v12.4s, v15.s[2]\n"
"add v22.4s, v22.4s, v13.4s\n"
"add v23.4s, v23.4s, v13.4s\n"
"dup v13.4s, v15.s[3]\n"
"add v24.4s, v24.4s, v10.4s\n"
"add v25.4s, v25.4s, v10.4s\n"
"add v26.4s, v26.4s, v11.4s\n"
"add v27.4s, v27.4s, v11.4s\n"
"add v28.4s, v28.4s, v12.4s\n"
"add v29.4s, v29.4s, v12.4s\n"
"add v30.4s, v30.4s, v13.4s\n"
"add v31.4s, v31.4s, v13.4s\n"
"7:\n"

"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
"beq 401f\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
"dup v10.4s, w5\n"  // create lhs_zero_point_vec
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
"add x5, x5, %x[col], lsl #2\n"
// Load 8 rhs_sums values.
"ld1 {v14.2s}, [x5], #8\n"
"ldr x7, [x5], #8\n"
"ld1 {v15.2s}, [x5], #8\n"
"ins v14.d[1], x7\n"
"ldr x7, [x5], #8\n"
"ins v15.d[1], x7\n"
// Subtract rhs_sums * lhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"mls v16.4s, v10.4s, v14.s[0]\n"
"mls v17.4s, v10.4s, v14.s[0]\n"
"mls v18.4s, v10.4s, v14.s[1]\n"
"mls v19.4s, v10.4s, v14.s[1]\n"
"mls v20.4s, v10.4s, v14.s[2]\n"
"mls v21.4s, v10.4s, v14.s[2]\n"
"mls v22.4s, v10.4s, v14.s[3]\n"
"mls v23.4s, v10.4s, v14.s[3]\n"
"mls v24.4s, v10.4s, v15.s[0]\n"
"mls v25.4s, v10.4s, v15.s[0]\n"
"mls v26.4s, v10.4s, v15.s[1]\n"
"mls v27.4s, v10.4s, v15.s[1]\n"
"mls v28.4s, v10.4s, v15.s[2]\n"
"mls v29.4s, v10.4s, v15.s[2]\n"
"mls v30.4s, v10.4s, v15.s[3]\n"
"mls v31.4s, v10.4s, v15.s[3]\n"
"401:\n"

"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
"beq 402f\n"
"ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
"add x2, x2, %x[row], lsl #2\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
"ins v13.s[1], w5\n" // rhs_zero_point
// Load 8 lhs_sums values.
"ld1 {v11.2s}, [x2], #8\n"
"ldr x4, [x2], #8\n"
"ins v11.d[1], x4\n"
"ld1 {v12.2s}, [x2], #8\n"
"ldr x4, [x2], #8\n"
"ins v12.d[1], x4\n"
// Compute lhs_sums * rhs_zero_point.
"mul v11.4s, v11.4s, v13.s[1]\n"
"mul v12.4s, v12.4s, v13.s[1]\n"
// Subtract lhs_sums * rhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"sub v16.4s, v16.4s, v11.4s\n"
"sub v17.4s, v17.4s, v12.4s\n"
"sub v18.4s, v18.4s, v11.4s\n"
"sub v19.4s, v19.4s, v12.4s\n"
"sub v20.4s, v20.4s, v11.4s\n"
"sub v21.4s, v21.4s, v12.4s\n"
"sub v22.4s, v22.4s, v11.4s\n"
"sub v23.4s, v23.4s, v12.4s\n"
"sub v24.4s, v24.4s, v11.4s\n"
"sub v25.4s, v25.4s, v12.4s\n"
"sub v26.4s, v26.4s, v11.4s\n"
"sub v27.4s, v27.4s, v12.4s\n"
"sub v28.4s, v28.4s, v11.4s\n"
"sub v29.4s, v29.4s, v12.4s\n"
"sub v30.4s, v30.4s, v11.4s\n"
"sub v31.4s, v31.4s, v12.4s\n"

"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"

"402:\n"

// At this point we have computed the final int32 values. Now we
// start down-quantizing them to obtain the final 8bit values from them.

// As part of this down-quantization, our int32 values will be
// multiplied by a multiplier that has a fixed-point component and an
// exponent component.

// Load the exponent part of the multiplier.
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
// Compute the multiplier_exponent pointer
"ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"add x5, x1, x3, lsl #2\n"
"csel x1, x1, x5, eq\n"
// Load multiplier_exponent
"ldr q9, [x1]\n"
"ldr q10, [x1, #16]\n"
// Separate positive and negative exponents
"smin v11.4s, v8.4s, v9.4s\n"
"smin v12.4s, v8.4s, v10.4s\n"
"sub v9.4s, v9.4s, v11.4s\n"
"sub v10.4s, v10.4s, v12.4s\n"

// Compute the multiplier_fixedpoint pointer
"ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
"add x5, x4, x3, lsl #2\n"
"csel x4, x4, x5, eq\n"
// Load multiplier_fixedpoint
"ldr q14, [x4]\n"
"ldr q15, [x4, #16]\n"

// Jump based on channel dimension.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
"bne 8f\n"
// Case where channels are rows

// Apply the positive exponent part of the multiplier.
"sshl v16.4s, v16.4s, v9.4s\n"
"sshl v17.4s, v17.4s, v10.4s\n"
"sshl v18.4s, v18.4s, v9.4s\n"
"sshl v19.4s, v19.4s, v10.4s\n"
"sshl v20.4s, v20.4s, v9.4s\n"
"sshl v21.4s, v21.4s, v10.4s\n"
"sshl v22.4s, v22.4s, v9.4s\n"
"sshl v23.4s, v23.4s, v10.4s\n"
"sshl v24.4s, v24.4s, v9.4s\n"
"sshl v25.4s, v25.4s, v10.4s\n"
"sshl v26.4s, v26.4s, v9.4s\n"
"sshl v27.4s, v27.4s, v10.4s\n"
"sshl v28.4s, v28.4s, v9.4s\n"
"sshl v29.4s, v29.4s, v10.4s\n"
"sshl v30.4s, v30.4s, v9.4s\n"
"sshl v31.4s, v31.4s, v10.4s\n"
"10:\n"

// Apply the fixed-point part of the multiplier.
//
// ... and, interleaved into that:
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 32 bytes of
// each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
// in the rest of the work on the current block.
"ld1 {v0.8b}, [%[lhs_ptr]], #8\n"
"sqdmulh v16.4s, v16.4s, v14.4s\n"
"ldr x1, [%[lhs_ptr]], #8\n"
"sqdmulh v17.4s, v17.4s, v15.4s\n"
"ld1 {v1.8b}, [%[lhs_ptr]], #8\n"
"sqdmulh v18.4s, v18.4s, v14.4s\n"
"ldr x2, [%[lhs_ptr]], #8\n"
"sqdmulh v19.4s, v19.4s, v15.4s\n"
"ld1 {v2.8b}, [%[rhs_ptr]], #8\n"
"sqdmulh v20.4s, v20.4s, v14.4s\n"
"ldr x5, [%[rhs_ptr]], #8\n"
"sqdmulh v21.4s, v21.4s, v15.4s\n"
"ld1 {v3.8b}, [%[rhs_ptr]], #8\n"
"sqdmulh v22.4s, v22.4s, v14.4s\n"
"ldr x6, [%[rhs_ptr]], #8\n"
"sqdmulh v23.4s, v23.4s, v15.4s\n"
"sqdmulh v24.4s, v24.4s, v14.4s\n"
"sqdmulh v25.4s, v25.4s, v15.4s\n"
"sqdmulh v26.4s, v26.4s, v14.4s\n"
"sqdmulh v27.4s, v27.4s, v15.4s\n"
"sqdmulh v28.4s, v28.4s, v14.4s\n"
"sqdmulh v29.4s, v29.4s, v15.4s\n"
"sqdmulh v30.4s, v30.4s, v14.4s\n"
"sqdmulh v31.4s, v31.4s, v15.4s\n"

// Apply the negative exponent part of the multiplier.
"srshl v16.4s, v16.4s, v11.4s\n"
"srshl v17.4s, v17.4s, v12.4s\n"
"srshl v18.4s, v18.4s, v11.4s\n"
"srshl v19.4s, v19.4s, v12.4s\n"
"srshl v20.4s, v20.4s, v11.4s\n"
"srshl v21.4s, v21.4s, v12.4s\n"
"srshl v22.4s, v22.4s, v11.4s\n"
"srshl v23.4s, v23.4s, v12.4s\n"
"srshl v24.4s, v24.4s, v11.4s\n"
"srshl v25.4s, v25.4s, v12.4s\n"
"ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"srshl v26.4s, v26.4s, v11.4s\n"
"ins v13.h[4], w4\n" // dst_zero_point
"srshl v27.4s, v27.4s, v12.4s\n"
"ins v0.d[1], x1\n"
"srshl v28.4s, v28.4s, v11.4s\n"
"ins v1.d[1], x2\n"
"srshl v29.4s, v29.4s, v12.4s\n"
"ins v2.d[1], x5\n"
"srshl v30.4s, v30.4s, v11.4s\n"
"ins v3.d[1], x6\n"
"srshl v31.4s, v31.4s, v12.4s\n"
"b 9f\n"

"8:\n"
// Case where channels are columns

// Apply the positive exponent part of the multiplier.
"dup v4.4s, v9.s[0]\n"
"dup v5.4s, v9.s[1]\n"
"sshl v16.4s, v16.4s, v4.4s\n"
"dup v6.4s, v9.s[2]\n"
"sshl v17.4s, v17.4s, v4.4s\n"
"dup v7.4s, v9.s[3]\n"
"sshl v18.4s, v18.4s, v5.4s\n"
"dup v4.4s, v10.s[0]\n"
"sshl v19.4s, v19.4s, v5.4s\n"
"dup v5.4s, v10.s[1]\n"
"sshl v20.4s, v20.4s, v6.4s\n"
"sshl v21.4s, v21.4s, v6.4s\n"
"dup v6.4s, v10.s[2]\n"
"sshl v22.4s, v22.4s, v7.4s\n"
"sshl v23.4s, v23.4s, v7.4s\n"
"dup v7.4s, v10.s[3]\n"
"sshl v24.4s, v24.4s, v4.4s\n"
"sshl v25.4s, v25.4s, v4.4s\n"
"sshl v26.4s, v26.4s, v5.4s\n"
"sshl v27.4s, v27.4s, v5.4s\n"
"sshl v28.4s, v28.4s, v6.4s\n"
"sshl v29.4s, v29.4s, v6.4s\n"
"sshl v30.4s, v30.4s, v7.4s\n"
"sshl v31.4s, v31.4s, v7.4s\n"
"11:\n"

// Apply the fixed-point part of the multiplier.
//
// ... and, interleaved into that:
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 32 bytes of
// each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
// in the rest of the work on the current block.
"ld1 {v0.8b}, [%[lhs_ptr]], #8\n"
"sqdmulh v16.4s, v16.4s, v14.s[0]\n"
"ldr x1, [%[lhs_ptr]], #8\n"
"sqdmulh v17.4s, v17.4s, v14.s[0]\n"
"ld1 {v1.8b}, [%[lhs_ptr]], #8\n"
"sqdmulh v18.4s, v18.4s, v14.s[1]\n"
"ldr x2, [%[lhs_ptr]], #8\n"
"sqdmulh v19.4s, v19.4s, v14.s[1]\n"
"ld1 {v2.8b}, [%[rhs_ptr]], #8\n"
"sqdmulh v20.4s, v20.4s, v14.s[2]\n"
"ldr x5, [%[rhs_ptr]], #8\n"
"sqdmulh v21.4s, v21.4s, v14.s[2]\n"
"ld1 {v3.8b}, [%[rhs_ptr]], #8\n"
"sqdmulh v22.4s, v22.4s, v14.s[3]\n"
"ldr x6, [%[rhs_ptr]], #8\n"
"sqdmulh v23.4s, v23.4s, v14.s[3]\n"
"dup v4.4s, v11.s[0]\n"
"sqdmulh v24.4s, v24.4s, v15.s[0]\n"
"dup v5.4s, v11.s[1]\n"
"sqdmulh v25.4s, v25.4s, v15.s[0]\n"
"dup v6.4s, v11.s[2]\n"
"sqdmulh v26.4s, v26.4s, v15.s[1]\n"
"dup v7.4s, v11.s[3]\n"
"sqdmulh v27.4s, v27.4s, v15.s[1]\n"
"sqdmulh v28.4s, v28.4s, v15.s[2]\n"
"sqdmulh v29.4s, v29.4s, v15.s[2]\n"
"sqdmulh v30.4s, v30.4s, v15.s[3]\n"
"sqdmulh v31.4s, v31.4s, v15.s[3]\n"

// Apply the negative exponent part of the multiplier.
"srshl v16.4s, v16.4s, v4.4s\n"
"srshl v17.4s, v17.4s, v4.4s\n"
"dup v4.4s, v12.s[0]\n"
"srshl v18.4s, v18.4s, v5.4s\n"
"srshl v19.4s, v19.4s, v5.4s\n"
"dup v5.4s, v12.s[1]\n"
"srshl v20.4s, v20.4s, v6.4s\n"
"srshl v21.4s, v21.4s, v6.4s\n"
"dup v6.4s, v12.s[2]\n"
"srshl v22.4s, v22.4s, v7.4s\n"
"srshl v23.4s, v23.4s, v7.4s\n"
"dup v7.4s, v12.s[3]\n"
"srshl v24.4s, v24.4s, v4.4s\n"
"ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"srshl v25.4s, v25.4s, v4.4s\n"
"ins v13.h[4], w4\n" // dst_zero_point
"srshl v26.4s, v26.4s, v5.4s\n"
"ins v0.d[1], x1\n"
"srshl v27.4s, v27.4s, v5.4s\n"
"ins v1.d[1], x2\n"
"srshl v28.4s, v28.4s, v6.4s\n"
"ins v2.d[1], x5\n"
"srshl v29.4s, v29.4s, v6.4s\n"
"ins v3.d[1], x6\n"
"srshl v30.4s, v30.4s, v7.4s\n"
"srshl v31.4s, v31.4s, v7.4s\n"
"9:\n"

"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"

RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"

// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
"sqxtn v18.4h, v20.4s\n"
"sqxtn2 v18.8h, v21.4s\n"
"sqxtn v19.4h, v22.4s\n"
"sqxtn2 v19.8h, v23.4s\n"
"sqxtn v20.4h, v24.4s\n"
"sqxtn2 v20.8h, v25.4s\n"
"sqxtn v21.4h, v26.4s\n"
"sqxtn2 v21.8h, v27.4s\n"
"sqxtn v22.4h, v28.4s\n"
"sqxtn2 v22.8h, v29.4s\n"
"sqxtn v23.4h, v30.4s\n"
"sqxtn2 v23.8h, v31.4s\n"

// Destination zero_point
"dup v14.8h, v13.h[4]\n"
// At this point, v24 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)

// Add the destination zero point
"add v16.8h, v16.8h, v14.8h\n"
"add v17.8h, v17.8h, v14.8h\n"
"add v18.8h, v18.8h, v14.8h\n"
"add v19.8h, v19.8h, v14.8h\n"
"add v20.8h, v20.8h, v14.8h\n"
"add v21.8h, v21.8h, v14.8h\n"
"add v22.8h, v22.8h, v14.8h\n"
"add v23.8h, v23.8h, v14.8h\n"

// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
// Cast-and-saturate from int16 to uint8
"sqxtun v16.8b, v16.8h\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"sqxtun2 v16.16b, v17.8h\n"
"sqxtun v17.8b, v18.8h\n"
"sqxtun2 v17.16b, v19.8h\n"
"sqxtun v18.8b, v20.8h\n"
"sqxtun2 v18.16b, v21.8h\n"
"sqxtun v19.8b, v22.8h\n"
"sqxtun2 v19.16b, v23.8h\n"

"dup v14.16b, w2\n"  // clamp_min
"dup v15.16b, w3\n"  // clamp_max

// Compute how much of the 8x8 block of destination 8bit values we
// have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
// Apply the clamp_min bound
"umax v16.16b, v16.16b, v14.16b\n"
"sub w2, %w[dst_cols], %w[col]\n"
"umax v17.16b, v17.16b, v14.16b\n"
"mov w3, #8\n"
"umax v18.16b, v18.16b, v14.16b\n"
"cmp w1, #8\n"
"umax v19.16b, v19.16b, v14.16b\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
// Apply the clamp_max bound
"umin v16.16b, v16.16b, v15.16b\n"
"cmp w2, #8\n"
"umin v17.16b, v17.16b, v15.16b\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
"umin v18.16b, v18.16b, v15.16b\n"
"umin v19.16b, v19.16b, v15.16b\n"

// Make it so that all of the final 8bit values are stored in the
// first 64bits of 128bit NEON registers, so they can be stored
// by 64bit st1 store instructions with byte alignment.
"dup d20, v16.d[1]\n"
"dup d21, v17.d[1]\n"
"dup d22, v18.d[1]\n"
"dup d23, v19.d[1]\n"

// Test if w1==8 && w2==8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 30f\n"
// Not all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #8\n"
"b 31f\n"
"30:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"31:\n"

// Write our 8bit values to the destination described by
// (x3 address, x4 stride).
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v16.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v16)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v20.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v20)
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010  // sdot v16.4s, v0.16b, v2.4b[0]\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v17.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v17)
".word 0x4f82e814  // sdot v20.4s, v0.16b, v2.4b[2]\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v21.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v21)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v18.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v18)
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v22.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v22)
".word 0x4fa2e012  // sdot v18.4s, v0.16b, v2.4b[1]\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v19.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v19)
".word 0x4fa2e816  // sdot v22.4s, v0.16b, v2.4b[3]\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
"st1 {v23.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v23)

// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 41f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #8\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"41:\n"
"add %[dst_ptr], %[dst_ptr], #8\n"

// At this point we have completely finished writing values to the
// destination matrix for the current block.

"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"

RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"

// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
"sqxtn v18.4h, v20.4s\n"
"sqxtn2 v18.8h, v21.4s\n"
"sqxtn v19.4h, v22.4s\n"
"sqxtn2 v19.8h, v23.4s\n"
"sqxtn v20.4h, v24.4s\n"
"sqxtn2 v20.8h, v25.4s\n"
"sqxtn v21.4h, v26.4s\n"
"sqxtn2 v21.8h, v27.4s\n"
"sqxtn v22.4h, v28.4s\n"
"sqxtn2 v22.8h, v29.4s\n"
"sqxtn v23.4h, v30.4s\n"
"sqxtn2 v23.8h, v31.4s\n"

// Destination zero_point
"dup v14.8h, v13.h[4]\n"
// At this point, v24 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)

// Add the destination zero point
"add v16.8h, v16.8h, v14.8h\n"
"add v17.8h, v17.8h, v14.8h\n"
"add v18.8h, v18.8h, v14.8h\n"
"add v19.8h, v19.8h, v14.8h\n"
"add v20.8h, v20.8h, v14.8h\n"
"add v21.8h, v21.8h, v14.8h\n"
"add v22.8h, v22.8h, v14.8h\n"
"add v23.8h, v23.8h, v14.8h\n"

// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
// Cast-and-saturate from int16 to int8
"sqxtn v16.8b, v16.8h\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"sqxtn2 v16.16b, v17.8h\n"
"sqxtn v17.8b, v18.8h\n"
"sqxtn2 v17.16b, v19.8h\n"
"sqxtn v18.8b, v20.8h\n"
"sqxtn2 v18.16b, v21.8h\n"
"sqxtn v19.8b, v22.8h\n"
"sqxtn2 v19.16b, v23.8h\n"

"dup v14.16b, w2\n"  // clamp_min
"dup v15.16b, w3\n"  // clamp_max

// Compute how much of the 8x8 block of destination 8bit values we
// have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
5898 "sub w1, %w[dst_rows], %w[row]\n"
5899 // Apply the clamp_min bound
5900 "smax v16.16b, v16.16b, v14.16b\n"
5901 "sub w2, %w[dst_cols], %w[col]\n"
5902 "smax v17.16b, v17.16b, v14.16b\n"
5903 "mov w3, #8\n"
5904 "smax v18.16b, v18.16b, v14.16b\n"
5905 "cmp w1, #8\n"
5906 "smax v19.16b, v19.16b, v14.16b\n"
5907 // Compute w1 = how many rows of the 8x8 block fit
5908 "csel w1, w1, w3, le\n"
5909 // Apply the clamp_max bound
5910 "smin v16.16b, v16.16b, v15.16b\n"
5911 "cmp w2, #8\n"
5912 "smin v17.16b, v17.16b, v15.16b\n"
5913 // Compute w2 = how many cols of the 8x8 block fit
5914 "csel w2, w2, w3, le\n"
5915 "smin v18.16b, v18.16b, v15.16b\n"
5916 "smin v19.16b, v19.16b, v15.16b\n"
5917
5918 // Make it so that all of the final 8bit values are stored in the
5919 // first 64bits of 128bit NEON registers, so they can be stored
5920 // by 64bit st1 store instructions with byte alignment.
5921 "dup d20, v16.d[1]\n"
5922 "dup d21, v17.d[1]\n"
5923 "dup d22, v18.d[1]\n"
5924 "dup d23, v19.d[1]\n"
5925
5926 // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
5927 "cmp w1, w3\n"
5928 "ccmp w2, w3, 0, eq\n"
5929 // Yes, all of the 8x8 block fits, go to fast path.
5930 "beq 130f\n"
5931 // Not all of the 8x8 block fits.
5932 // Set (x3 address, x4 stride) to write to dst_tmp_buf
5933 "mov x3, %[dst_tmp_buf]\n"
5934 "mov x4, #8\n"
5935 "b 131f\n"
5936 "130:\n"
5937 // Yes, all of the 8x8 block fits.
5938 // Set (x3 address, x4 stride) to write directly to destination matrix.
5939 "mov x3, %[dst_ptr]\n"
5940 "mov x4, x11\n"
5941 "131:\n"
5942
5943 // Write our 8bit values to the destination described by
5944 // (x3 address, x4 stride).
5945 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5946 "st1 {v16.8b}, [x3], x4\n"
5947 RUY_MAKE_ZERO(v16)
5948 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5949 "st1 {v20.8b}, [x3], x4\n"
5950 RUY_MAKE_ZERO(v20)
5951 // For the next block: perform the first few multiply-adds on the data
5952 // that we have already loaded.
5953 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
5954 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5955 "st1 {v17.8b}, [x3], x4\n"
5956 RUY_MAKE_ZERO(v17)
5957 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
5958 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5959 "st1 {v21.8b}, [x3], x4\n"
5960 RUY_MAKE_ZERO(v21)
5961 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5962 "st1 {v18.8b}, [x3], x4\n"
5963 RUY_MAKE_ZERO(v18)
5964 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5965 "st1 {v22.8b}, [x3], x4\n"
5966 RUY_MAKE_ZERO(v22)
5967 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
5968 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5969 "st1 {v19.8b}, [x3], x4\n"
5970 RUY_MAKE_ZERO(v19)
5971 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
5972 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
5973 "st1 {v23.8b}, [x3], x4\n"
5974 RUY_MAKE_ZERO(v23)
5975
5976 // If all of the 8x8 block fits, we just finished writing it to the
5977 // destination, so we skip the next part.
5978 "beq 141f\n"
5979 // Not all of the 8x8 block fits in the destination matrix. We just
5980 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
5981 // it to copy into the destination matrix the part that fits.
5982 "mov x3, %[dst_tmp_buf]\n"
5983 "mov x4, %[dst_ptr]\n"
5984 "mov w6, #0\n"
5985 "150:\n"
5986 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
5987 "mov w5, #0\n"
5988 "151:\n"
5989 "ldrb w7, [x3, w5, uxtw]\n"
5990 "strb w7, [x4, w5, uxtw]\n"
5991 "add w5, w5, #1\n"
5992 "cmp w5, w1\n"
5993 "blt 151b\n"
5994 "add w6, w6, #1\n"
5995 "add x3, x3, #8\n"
5996 "add x4, x4, x11\n"
5997 "cmp w6, w2\n"
5998 "blt 150b\n"
5999 "141:\n"
6000 "add %[dst_ptr], %[dst_ptr], #8\n"
6001
6002 // At this point we have completely finished writing values to the
6003 // destination matrix for the current block.
6004
6005 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
6006
RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"

// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"saddw v16.4s, v16.4s, v14.4h\n"
"saddw v17.4s, v17.4s, v14.4h\n"
"saddw v18.4s, v18.4s, v14.4h\n"
"saddw v19.4s, v19.4s, v14.4h\n"
"saddw v20.4s, v20.4s, v14.4h\n"
"saddw v21.4s, v21.4s, v14.4h\n"
"saddw v22.4s, v22.4s, v14.4h\n"
"saddw v23.4s, v23.4s, v14.4h\n"
"saddw v24.4s, v24.4s, v14.4h\n"
"saddw v25.4s, v25.4s, v14.4h\n"
"saddw v26.4s, v26.4s, v14.4h\n"
"saddw v27.4s, v27.4s, v14.4h\n"
"saddw v28.4s, v28.4s, v14.4h\n"
"saddw v29.4s, v29.4s, v14.4h\n"
"saddw v30.4s, v30.4s, v14.4h\n"
"saddw v31.4s, v31.4s, v14.4h\n"

// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
"sqxtn v18.4h, v20.4s\n"
"sqxtn2 v18.8h, v21.4s\n"
"sqxtn v19.4h, v22.4s\n"
"sqxtn2 v19.8h, v23.4s\n"
"sqxtn v20.4h, v24.4s\n"
"sqxtn2 v20.8h, v25.4s\n"
"sqxtn v21.4h, v26.4s\n"
"sqxtn2 v21.8h, v27.4s\n"
"sqxtn v22.4h, v28.4s\n"
"sqxtn2 v22.8h, v29.4s\n"
"sqxtn v23.4h, v30.4s\n"
"sqxtn2 v23.8h, v31.4s\n"

// At this point, v24 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)

// Load the clamp_min, clamp_max bounds
"ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.8h, w2\n"  // clamp_min
"dup v15.8h, w3\n"  // clamp_max

// Apply the clamp_min bound
"smax v16.8h, v16.8h, v14.8h\n"
"smax v17.8h, v17.8h, v14.8h\n"
"smax v18.8h, v18.8h, v14.8h\n"
"smax v19.8h, v19.8h, v14.8h\n"
"smax v20.8h, v20.8h, v14.8h\n"
"smax v21.8h, v21.8h, v14.8h\n"
"smax v22.8h, v22.8h, v14.8h\n"
"smax v23.8h, v23.8h, v14.8h\n"
// Apply the clamp_max bound
"smin v16.8h, v16.8h, v15.8h\n"
"smin v17.8h, v17.8h, v15.8h\n"
"smin v18.8h, v18.8h, v15.8h\n"
"smin v19.8h, v19.8h, v15.8h\n"
"smin v20.8h, v20.8h, v15.8h\n"
"smin v21.8h, v21.8h, v15.8h\n"
"smin v22.8h, v22.8h, v15.8h\n"
"smin v23.8h, v23.8h, v15.8h\n"

// Compute how much of the 8x8 block of destination 16bit values we
// have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"

// Test if w1==8 && w2==8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 230f\n"
// Not all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #16\n"
"b 231f\n"
"230:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"231:\n"
6114
6115 // Write our 8bit values to the destination described by
6116 // (x3 address, x4 stride).
6117 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6118 "st1 {v16.8h}, [x3], x4\n"
6119 RUY_MAKE_ZERO(v16)
6120 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6121 "st1 {v17.8h}, [x3], x4\n"
6122 RUY_MAKE_ZERO(v17)
6123 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6124 "st1 {v18.8h}, [x3], x4\n"
6125 RUY_MAKE_ZERO(v18)
6126 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6127 "st1 {v19.8h}, [x3], x4\n"
6128 RUY_MAKE_ZERO(v19)
6129 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6130 "st1 {v20.8h}, [x3], x4\n"
6131 RUY_MAKE_ZERO(v20)
6132 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6133 "st1 {v21.8h}, [x3], x4\n"
6134 RUY_MAKE_ZERO(v21)
6135 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6136 "st1 {v22.8h}, [x3], x4\n"
6137 RUY_MAKE_ZERO(v22)
6138 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6139 "st1 {v23.8h}, [x3], x4\n"
6140 RUY_MAKE_ZERO(v23)
6141
6142 // For the next block: perform the first few multiply-adds on the data
6143 // that we have already loaded.
6144 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
6145 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
6146 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
6147 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
6148
6149 // If all of the 8x8 block fits, we just finished writing it to the
6150 // destination, so we skip the next part.
6151 "beq 241f\n"
6152 // Not all of the 8x8 block fits in the destination matrix. We just
6153 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
6154 // it to copy into the destination matrix the part that fits.
6155 "mov x3, %[dst_tmp_buf]\n"
6156 "mov x4, %[dst_ptr]\n"
6157 "mov w6, #0\n"
6158 "250:\n"
6159 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6160 "mov w5, #0\n"
6161 "251:\n"
6162 "ldrsh w7, [x3, x5, lsl #1]\n"
6163 "strh w7, [x4, x5, lsl #1]\n"
6164 "add w5, w5, #1\n"
6165 "cmp w5, w1\n"
6166 "blt 251b\n"
6167 "add w6, w6, #1\n"
6168 "add x3, x3, #16\n"
6169 "add x4, x4, x11\n"
6170 "cmp w6, w2\n"
6171 "blt 250b\n"
6172 "241:\n"
6173 "add %[dst_ptr], %[dst_ptr], #16\n"
6174 // At this point we have completely finished writing values to the
6175 // destination matrix for the current block.
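        // The scalar boundary copy above is roughly this C sketch, where
        // dst_stride is the destination byte stride held in x11 and the tmp
        // buffer holds 8 int16 values per column in this path (illustrative
        // only):
        //   for (int c = 0; c < cols_fit; ++c)
        //     for (int r = 0; r < rows_fit; ++r)
        //       dst16[c * (dst_stride / 2) + r] = tmp16[c * 8 + r];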
6176
6177 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
6178
6179 RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
6180
6181 "ld1 {v0.8b}, [%[lhs_ptr]], #8\n"
6182 "ldr x1, [%[lhs_ptr]], #8\n"
6183 "ld1 {v1.8b}, [%[lhs_ptr]], #8\n"
6184 "ldr x2, [%[lhs_ptr]], #8\n"
6185 "ld1 {v2.8b}, [%[rhs_ptr]], #8\n"
6186 "ldr x5, [%[rhs_ptr]], #8\n"
6187 "ld1 {v3.8b}, [%[rhs_ptr]], #8\n"
6188 "ldr x6, [%[rhs_ptr]], #8\n"
6189 "ins v0.d[1], x1\n"
6190 "ins v1.d[1], x2\n"
6191 "ins v2.d[1], x5\n"
6192 "ins v3.d[1], x6\n"
6193
        // Since the store type is the same as the accum type, there is no
        // need to downcast, and no need to clamp by min/max either.
6196
        // Compute how much of the 8x8 block of destination 32bit values that
        // we have computed fits in the destination matrix. Typically, all of
        // it fits, but when the destination matrix shape is not a multiple
        // of 8x8, there are some 8x8 blocks along the boundaries that do
        // not fit entirely.
        "sub w1, %w[dst_rows], %w[row]\n"
        "sub w2, %w[dst_cols], %w[col]\n"
        "mov w3, #8\n"
        "cmp w1, #8\n"
        // Compute w1 = how many rows of the 8x8 block fit
        "csel w1, w1, w3, le\n"
        "cmp w2, #8\n"
        // Compute w2 = how many cols of the 8x8 block fit
        "csel w2, w2, w3, le\n"

        // Test if w1 == 8 && w2 == 8, i.e. if all of the 8x8 block fits.
        "cmp w1, w3\n"
        "ccmp w2, w3, 0, eq\n"
6215 // Yes, all of the 8x8 block fits, go to fast path.
6216 "beq 330f\n"
6217 // Not all of the 8x8 block fits.
6218 // Write to dst_tmp_buf
6219 "mov x3, %[dst_tmp_buf]\n"
6220 "st1 {v16.4s}, [x3], #16\n"
6221 RUY_MAKE_ZERO(v16)
6222 "st1 {v17.4s}, [x3], #16\n"
6223 RUY_MAKE_ZERO(v17)
6224 "st1 {v18.4s}, [x3], #16\n"
6225 RUY_MAKE_ZERO(v18)
6226 "st1 {v19.4s}, [x3], #16\n"
6227 RUY_MAKE_ZERO(v19)
6228 "st1 {v20.4s}, [x3], #16\n"
6229 RUY_MAKE_ZERO(v20)
6230 "st1 {v21.4s}, [x3], #16\n"
6231 RUY_MAKE_ZERO(v21)
6232 "st1 {v22.4s}, [x3], #16\n"
6233 RUY_MAKE_ZERO(v22)
6234 "st1 {v23.4s}, [x3], #16\n"
6235 RUY_MAKE_ZERO(v23)
6236 "st1 {v24.4s}, [x3], #16\n"
6237 RUY_MAKE_ZERO(v24)
6238 "st1 {v25.4s}, [x3], #16\n"
6239 RUY_MAKE_ZERO(v25)
6240 "st1 {v26.4s}, [x3], #16\n"
6241 RUY_MAKE_ZERO(v26)
6242 "st1 {v27.4s}, [x3], #16\n"
6243 RUY_MAKE_ZERO(v27)
6244 "st1 {v28.4s}, [x3], #16\n"
6245 RUY_MAKE_ZERO(v28)
6246 "st1 {v29.4s}, [x3], #16\n"
6247 RUY_MAKE_ZERO(v29)
6248 "st1 {v30.4s}, [x3], #16\n"
6249 RUY_MAKE_ZERO(v30)
6250 "st1 {v31.4s}, [x3], #16\n"
6251 RUY_MAKE_ZERO(v31)
6252
6253 "b 331f\n"
6254
6255 "330:\n"
6256 // Yes, all of the 8x8 block fits.
6257 "mov x4, %[dst_ptr]\n"
6258 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6259 "st1 {v16.4s, v17.4s}, [x4], x11\n"
6260 RUY_MAKE_ZERO(v16)
6261 RUY_MAKE_ZERO(v17)
6262 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6263 "st1 {v18.4s, v19.4s}, [x4], x11\n"
6264 RUY_MAKE_ZERO(v18)
6265 RUY_MAKE_ZERO(v19)
6266 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6267 "st1 {v20.4s, v21.4s}, [x4], x11\n"
6268 RUY_MAKE_ZERO(v20)
6269 RUY_MAKE_ZERO(v21)
6270 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6271 "st1 {v22.4s, v23.4s}, [x4], x11\n"
6272 RUY_MAKE_ZERO(v22)
6273 RUY_MAKE_ZERO(v23)
6274 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6275 "st1 {v24.4s, v25.4s}, [x4], x11\n"
6276 RUY_MAKE_ZERO(v24)
6277 RUY_MAKE_ZERO(v25)
6278 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6279 "st1 {v26.4s, v27.4s}, [x4], x11\n"
6280 RUY_MAKE_ZERO(v26)
6281 RUY_MAKE_ZERO(v27)
6282 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6283 "st1 {v28.4s, v29.4s}, [x4], x11\n"
6284 RUY_MAKE_ZERO(v28)
6285 RUY_MAKE_ZERO(v29)
6286 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6287 "st1 {v30.4s, v31.4s}, [x4], x11\n"
6288 RUY_MAKE_ZERO(v30)
6289 RUY_MAKE_ZERO(v31)
6290
6291 "331:\n"
6292
6293 // For the next block: perform the first few multiply-adds on the data
6294 // that we have already loaded.
6295 ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
6296 ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
6297 ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
6298 ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
6299
6300 // If all of the 8x8 block fits, we just finished writing it to the
6301 // destination, so we skip the next part.
6302 "beq 341f\n"
6303
6304 // Not all of the 8x8 block fits in the destination matrix. We just
6305 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
6306 // it to copy into the destination matrix the part that fits.
6307 "mov x3, %[dst_tmp_buf]\n"
6308 "mov x4, %[dst_ptr]\n"
6309 "mov w6, #0\n"
6310 "350:\n"
6311 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
6312 "mov w5, #0\n"
6313 "351:\n"
6314 "ldr w7, [x3, x5, lsl #2]\n"
6315 "str w7, [x4, x5, lsl #2]\n"
6316 "add w5, w5, #1\n"
6317 "cmp w5, w1\n"
6318 "blt 351b\n"
6319 "add w6, w6, #1\n"
6320 "add x3, x3, #32\n"
6321 "add x4, x4, x11\n"
6322 "cmp w6, w2\n"
6323 "blt 350b\n"
6324 "341:\n"
6325 "add %[dst_ptr], %[dst_ptr], #32\n"
6326 // At this point we have completely finished writing values to the
6327 // destination matrix for the current block.
6328
6329 RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
6330
6331 // Reload some params --- we had used x5 -- x7 for a few other things
6332 // since the last time we had loaded them.
6333 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
6334 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
6335
6336 // Move to the next block of the destination matrix, for the next iter
6337 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
6338 // been updated earlier.
6339 // Have we reached the end row?
6340 "cmp %w[row], w7\n"
6341 "beq 20f\n" // yes, end row.
6342 // Not end row. Move to the next row.
6343 "add %w[row], %w[row], #8\n"
6344 "b 21f\n"
6345 "20:\n"
6346 // Was already at end row.
6347 "mov %w[row], w6\n" // Move back to first row.
6348 "add %w[col], %w[col], #8\n" // Move to the next column.
6349 "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
6350 "mov %[dst_ptr], %[dst_col_ptr]\n"
6351 "21:\n"
6352
6353 // Main loop exit condition: have we hit the end column?
6354 "cmp %w[col], w8\n"
6355 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
6356 "ble 1b\n"
6357
6358 // clang-format on
6359
        : [lhs_col_ptr] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
6361 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
6362 [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
        : [params] "r"(&params), [dst_rows] "r"(params.dst_rows),
6364 [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
6365 [dst_type_id] "r"(params.dst_type_id)
6366 : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
6367 "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
6368 "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
6369 "v26", "v27", "v28", "v29", "v30", "v31");
6370 }
6371 #undef RUY_OFFSET_BIAS
6372 #undef RUY_OFFSET_LHS_SUMS
6373 #undef RUY_OFFSET_RHS_SUMS
6374 #undef RUY_OFFSET_LHS_BASE_PTR
6375 #undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT
6376 #undef RUY_OFFSET_MULTIPLIER_EXPONENT
6377 #undef RUY_OFFSET_RHS_BASE_PTR
6378 #undef RUY_OFFSET_DST_BASE_PTR
6379 #undef RUY_OFFSET_LHS_ZERO_POINT
6380 #undef RUY_OFFSET_RHS_ZERO_POINT
6381 #undef RUY_OFFSET_DST_ZERO_POINT
6382 #undef RUY_OFFSET_PROD_ZP_DEPTH
6383 #undef RUY_OFFSET_START_ROW
6384 #undef RUY_OFFSET_START_COL
6385 #undef RUY_OFFSET_LAST_ROW
6386 #undef RUY_OFFSET_LAST_COL
6387 #undef RUY_OFFSET_DST_ROWS
6388 #undef RUY_OFFSET_DST_COLS
6389 #undef RUY_OFFSET_LHS_STRIDE
6390 #undef RUY_OFFSET_RHS_STRIDE
6391 #undef RUY_OFFSET_DST_STRIDE
6392 #undef RUY_OFFSET_DEPTH
6393 #undef RUY_OFFSET_CLAMP_MIN
6394 #undef RUY_OFFSET_CLAMP_MAX
6395 #undef RUY_OFFSET_FLAGS
6396
6397 #define RUY_OFFSET_LHS_BASE_PTR 0
6398 #define RUY_OFFSET_RHS_BASE_PTR 8
6399 #define RUY_OFFSET_DST_BASE_PTR 16
6400 #define RUY_OFFSET_BIAS 24
6401 #define RUY_OFFSET_START_ROW 32
6402 #define RUY_OFFSET_START_COL 36
6403 #define RUY_OFFSET_LAST_ROW 40
6404 #define RUY_OFFSET_LAST_COL 44
6405 #define RUY_OFFSET_LHS_STRIDE 56
6406 #define RUY_OFFSET_RHS_STRIDE 60
6407 #define RUY_OFFSET_DST_STRIDE 64
6408 #define RUY_OFFSET_DEPTH 68
6409 #define RUY_OFFSET_CLAMP_MIN 72
6410 #define RUY_OFFSET_CLAMP_MAX 76
6411 #define RUY_OFFSET_FLAGS 80
6412
6413 template <typename Params>
void CheckOffsetsInKernelParamsFloat(const Params&) {
6415 static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
6416 static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, "");
6417 static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, "");
6418 static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
6419 static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
6420 static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, "");
6421 static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
6422 static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
6423 static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
6424 static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
6425 static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
6426 static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
6427 static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
6428 static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
6429 static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
6430 }
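
// These static_asserts pin down the struct layout that the asm kernels
// below read through hard-coded byte offsets, e.g. (as used later):
//   "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
// If a field ever moves, compilation fails here instead of the kernel
// silently reading the wrong bytes.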
6431
6432 // Just a plain float kernel; good enough for out-of-order cores.
6433 // The closest to it in the gemmlowp collection would be
6434 // NEON_64bit_GEMM_Float32_WithScalar,
6435 // https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3925
6436 //
// Besides ruy-ification, the main nuance here is that we stick to an 8x8
// width instead of the wider 12x8 that the register space permits and that
// the aforementioned gemmlowp kernel uses. Ruy likes powers of two for now
// and we don't have evidence that going beyond 8x8 is needed.
void KernelFloatNeon(const KernelParamsFloat<8, 8>& params) {
6442 CheckOffsetsInKernelParamsFloat(params);
6443 profiler::ScopeLabel label("Kernel (kNeon)");
6444
6445 const float* lhs_col_ptr = params.lhs_base_ptr;
6446 const float* rhs_col_ptr = params.rhs_base_ptr;
6447 const float* lhs_ptr = lhs_col_ptr;
6448 const float* rhs_ptr = rhs_col_ptr;
6449 float* dst_col_ptr = params.dst_base_ptr;
6450 float* dst_ptr = dst_col_ptr;
6451 int row = params.start_row;
6452 int col = params.start_col;
6453
6454 // The asm kernel below has the following NEON register allocation:
6455 //
6456 // v16 -- v31 are accumulators.
6457 // During accumulation, v0 -- v15 are used to load data from LHS and RHS.
  // At least v0 and v1 are used to load an 8x1 block of LHS, and v2 and
  // v3 are used to load a 1x8 block of RHS, like this:
  //
  //                                          RHS 1x8 block
  //                           /-----------------------------------------\
  //                           |v2.s[0] ... v2.s[3]   v3.s[0] ... v3.s[3]|
  //                           \-----------------------------------------/
  //        LHS 8x1 block
  //  /---------------------\  /-----------------------------------------\
  //  |        v0.s[0]      |  |v16.s[0]           ...           v30.s[0]|
  //  |         ...         |  |  ...                              ...   |
  //  |        v0.s[3]      |  |v16.s[3]           ...           v30.s[3]|
  //  |        v1.s[0]      |  |v17.s[0]           ...           v31.s[0]|
  //  |         ...         |  |  ...                              ...   |
  //  |        v1.s[3]      |  |v17.s[3]           ...           v31.s[3]|
  //  \---------------------/  \-----------------------------------------/
  //                                      accumulators 8x8 block
6475 //
  // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step
  // is repeated 4 times, using 4x more registers for LHS and RHS: instead
  // of v0 -- v3, that part uses v0 -- v15.
6479 //
  // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are
  // unused, and v8 -- v15 are used for loading parameters used for the
  // post-accumulation part of the kernel.
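  // In scalar terms, the elementary accumulation step implemented below is
  // one rank-1 update per level of depth (illustrative sketch; acc is the
  // 8x8 block held in v16 -- v31):
  //   for (int i = 0; i < 8; ++i)
  //     for (int j = 0; j < 8; ++j)
  //       acc[i][j] += lhs[i] * rhs[j];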
6483 asm volatile(
6484 #define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
6485
6486 // clang-format off
6487
6488 // Load some parameters into registers.
6489 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
6490 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
6491 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
6492 "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
6493 "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
6494 "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
6495 "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
6496 "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
6497
6498 // Load the first 32 bytes of LHS and RHS data.
6499 "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
6500 "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
6501 "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
6502 "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
6503
6504 // Clear accumulators.
6505 RUY_MAKE_ZERO(v16)
6506 RUY_MAKE_ZERO(v17)
6507 RUY_MAKE_ZERO(v18)
6508 RUY_MAKE_ZERO(v19)
6509 RUY_MAKE_ZERO(v20)
6510 RUY_MAKE_ZERO(v21)
6511 RUY_MAKE_ZERO(v22)
6512 RUY_MAKE_ZERO(v23)
6513 RUY_MAKE_ZERO(v24)
6514 RUY_MAKE_ZERO(v25)
6515 RUY_MAKE_ZERO(v26)
6516 RUY_MAKE_ZERO(v27)
6517 RUY_MAKE_ZERO(v28)
6518 RUY_MAKE_ZERO(v29)
6519 RUY_MAKE_ZERO(v30)
6520 RUY_MAKE_ZERO(v31)
6521
6522 // w1 is the number of levels of depth that we have already loaded
6523 // LHS and RHS data for. Corresponding to the initial ld1 instructions
6524 // above, this is currently 1.
6525 "mov w1, #1\n"
6526
6527 // Main loop of the whole GEMM, over rows and columns of the
6528 // destination matrix.
6529 "1:\n"
6530
6531 "fmla v16.4s, v0.4s, v2.s[0]\n"
6532 "fmla v18.4s, v0.4s, v2.s[1]\n"
6533 "fmla v20.4s, v0.4s, v2.s[2]\n"
6534 "fmla v22.4s, v0.4s, v2.s[3]\n"
6535
6536 #if RUY_OPT(MAX_STREAMING)
6537 "cmp w12, #8\n"
6538 "blt 78f\n"
6539 "and w2, w12, #-4\n"
6540
6541 "ld1 {v4.4s}, [%[lhs_ptr]], #16\n"
6542 "ld1 {v5.4s}, [%[lhs_ptr]], #16\n"
6543 "ld1 {v6.4s}, [%[rhs_ptr]], #16\n"
6544 "ld1 {v7.4s}, [%[rhs_ptr]], #16\n"
6545
6546 "ld1 {v8.4s}, [%[lhs_ptr]], #16\n"
6547 "ld1 {v9.4s}, [%[lhs_ptr]], #16\n"
6548 "ld1 {v10.4s}, [%[rhs_ptr]], #16\n"
6549 "ld1 {v11.4s}, [%[rhs_ptr]], #16\n"
6550
6551 "ld1 {v12.4s}, [%[lhs_ptr]], #16\n"
6552 "ld1 {v13.4s}, [%[lhs_ptr]], #16\n"
6553 "ld1 {v14.4s}, [%[rhs_ptr]], #16\n"
6554 "ld1 {v15.4s}, [%[rhs_ptr]], #16\n"
6555 "mov w1, #4\n"
6556
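        // The MAX_STREAMING path keeps 4 levels of depth in flight in
        // v0 -- v15 and processes depth in chunks of 4; roughly (sketch):
        //   int w2 = depth & ~3;                 // "and w2, w12, #-4"
        //   for (int w1 = 4; w1 < w2; w1 += 4) {
        //     // 4 rank-1 updates per iteration (label 80 below)
        //   }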
6557 "80:\n"
6558
6559 "add %[lhs_ptr], %[lhs_ptr], #128\n"
6560 "add %[rhs_ptr], %[rhs_ptr], #128\n"
6561
6562 "fmla v24.4s, v0.4s, v3.s[0]\n"
6563 "fmla v26.4s, v0.4s, v3.s[1]\n"
6564 "fmla v28.4s, v0.4s, v3.s[2]\n"
6565 "fmla v30.4s, v0.4s, v3.s[3]\n"
6566 "ldr q0, [%[lhs_ptr], #-128]\n"
6567 "fmla v25.4s, v1.4s, v3.s[0]\n"
6568 "fmla v27.4s, v1.4s, v3.s[1]\n"
6569 "fmla v29.4s, v1.4s, v3.s[2]\n"
6570 "fmla v31.4s, v1.4s, v3.s[3]\n"
6571 "ldr q3, [%[rhs_ptr], #-112]\n"
6572 "fmla v17.4s, v1.4s, v2.s[0]\n"
6573 "fmla v19.4s, v1.4s, v2.s[1]\n"
6574 "fmla v21.4s, v1.4s, v2.s[2]\n"
6575 "fmla v23.4s, v1.4s, v2.s[3]\n"
6576 "ldr q1, [%[lhs_ptr], #-112]\n"
6577 "fmla v16.4s, v4.4s, v6.s[0]\n"
6578 "fmla v18.4s, v4.4s, v6.s[1]\n"
6579 "ldr q2, [%[rhs_ptr], #-128]\n"
6580 "fmla v20.4s, v4.4s, v6.s[2]\n"
6581 "fmla v22.4s, v4.4s, v6.s[3]\n"
6582
6583 "fmla v24.4s, v4.4s, v7.s[0]\n"
6584 "fmla v26.4s, v4.4s, v7.s[1]\n"
6585 "fmla v28.4s, v4.4s, v7.s[2]\n"
6586 "fmla v30.4s, v4.4s, v7.s[3]\n"
6587 "ldr q4, [%[lhs_ptr], #-96]\n"
6588 "fmla v25.4s, v5.4s, v7.s[0]\n"
6589 "fmla v27.4s, v5.4s, v7.s[1]\n"
6590 "fmla v29.4s, v5.4s, v7.s[2]\n"
6591 "fmla v31.4s, v5.4s, v7.s[3]\n"
6592 "ldr q7, [%[rhs_ptr], #-80]\n"
6593 "fmla v17.4s, v5.4s, v6.s[0]\n"
6594 "fmla v19.4s, v5.4s, v6.s[1]\n"
6595 "fmla v21.4s, v5.4s, v6.s[2]\n"
6596 "fmla v23.4s, v5.4s, v6.s[3]\n"
6597 "ldr q5, [%[lhs_ptr], #-80]\n"
6598 "fmla v16.4s, v8.4s, v10.s[0]\n"
6599 "fmla v18.4s, v8.4s, v10.s[1]\n"
6600 "ldr q6, [%[rhs_ptr], #-96]\n"
6601 "fmla v20.4s, v8.4s, v10.s[2]\n"
6602 "fmla v22.4s, v8.4s, v10.s[3]\n"
6603
6604 "fmla v24.4s, v8.4s, v11.s[0]\n"
6605 "fmla v26.4s, v8.4s, v11.s[1]\n"
6606 "fmla v28.4s, v8.4s, v11.s[2]\n"
6607 "fmla v30.4s, v8.4s, v11.s[3]\n"
6608 "ldr q8, [%[lhs_ptr], #-64]\n"
6609 "fmla v25.4s, v9.4s, v11.s[0]\n"
6610 "fmla v27.4s, v9.4s, v11.s[1]\n"
6611 "fmla v29.4s, v9.4s, v11.s[2]\n"
6612 "fmla v31.4s, v9.4s, v11.s[3]\n"
6613 "ldr q11, [%[rhs_ptr], #-48]\n"
6614 "fmla v17.4s, v9.4s, v10.s[0]\n"
6615 "fmla v19.4s, v9.4s, v10.s[1]\n"
6616 "fmla v21.4s, v9.4s, v10.s[2]\n"
6617 "fmla v23.4s, v9.4s, v10.s[3]\n"
6618 "ldr q9, [%[lhs_ptr], #-48]\n"
6619 "fmla v16.4s, v12.4s, v14.s[0]\n"
6620 "fmla v18.4s, v12.4s, v14.s[1]\n"
6621 "ldr q10, [%[rhs_ptr], #-64]\n"
6622 "fmla v20.4s, v12.4s, v14.s[2]\n"
6623 "fmla v22.4s, v12.4s, v14.s[3]\n"
6624
6625 "fmla v24.4s, v12.4s, v15.s[0]\n"
6626 "fmla v26.4s, v12.4s, v15.s[1]\n"
6627 "fmla v28.4s, v12.4s, v15.s[2]\n"
6628 "fmla v30.4s, v12.4s, v15.s[3]\n"
6629 "ldr q12, [%[lhs_ptr], #-32]\n"
6630 "fmla v25.4s, v13.4s, v15.s[0]\n"
6631 "fmla v27.4s, v13.4s, v15.s[1]\n"
6632 "fmla v29.4s, v13.4s, v15.s[2]\n"
6633 "fmla v31.4s, v13.4s, v15.s[3]\n"
6634 "ldr q15, [%[rhs_ptr], #-16]\n"
6635 "fmla v17.4s, v13.4s, v14.s[0]\n"
6636 "fmla v19.4s, v13.4s, v14.s[1]\n"
6637 "fmla v21.4s, v13.4s, v14.s[2]\n"
6638 "fmla v23.4s, v13.4s, v14.s[3]\n"
6639 "ldr q13, [%[lhs_ptr], #-16]\n"
6640 "fmla v16.4s, v0.4s, v2.s[0]\n"
6641 "fmla v18.4s, v0.4s, v2.s[1]\n"
6642 "ldr q14, [%[rhs_ptr], #-32]\n"
6643 "fmla v20.4s, v0.4s, v2.s[2]\n"
6644 "fmla v22.4s, v0.4s, v2.s[3]\n"
6645
6646 "add w1, w1, #4\n"
6647 "cmp w1, w2\n"
6648 "blt 80b\n"
6649
6650 "fmla v16.4s, v4.4s, v6.s[0]\n"
6651 "fmla v18.4s, v4.4s, v6.s[1]\n"
6652 "fmla v20.4s, v4.4s, v6.s[2]\n"
6653 "fmla v22.4s, v4.4s, v6.s[3]\n"
6654 "fmla v24.4s, v4.4s, v7.s[0]\n"
6655 "fmla v26.4s, v4.4s, v7.s[1]\n"
6656 "fmla v28.4s, v4.4s, v7.s[2]\n"
6657 "fmla v30.4s, v4.4s, v7.s[3]\n"
6658 "fmla v25.4s, v5.4s, v7.s[0]\n"
6659 "fmla v27.4s, v5.4s, v7.s[1]\n"
6660 "fmla v29.4s, v5.4s, v7.s[2]\n"
6661 "fmla v31.4s, v5.4s, v7.s[3]\n"
6662 "fmla v17.4s, v5.4s, v6.s[0]\n"
6663 "fmla v19.4s, v5.4s, v6.s[1]\n"
6664 "fmla v21.4s, v5.4s, v6.s[2]\n"
6665 "fmla v23.4s, v5.4s, v6.s[3]\n"
6666
6667 "fmla v16.4s, v8.4s, v10.s[0]\n"
6668 "fmla v18.4s, v8.4s, v10.s[1]\n"
6669 "fmla v20.4s, v8.4s, v10.s[2]\n"
6670 "fmla v22.4s, v8.4s, v10.s[3]\n"
6671 "fmla v24.4s, v8.4s, v11.s[0]\n"
6672 "fmla v26.4s, v8.4s, v11.s[1]\n"
6673 "fmla v28.4s, v8.4s, v11.s[2]\n"
6674 "fmla v30.4s, v8.4s, v11.s[3]\n"
6675 "fmla v25.4s, v9.4s, v11.s[0]\n"
6676 "fmla v27.4s, v9.4s, v11.s[1]\n"
6677 "fmla v29.4s, v9.4s, v11.s[2]\n"
6678 "fmla v31.4s, v9.4s, v11.s[3]\n"
6679 "fmla v17.4s, v9.4s, v10.s[0]\n"
6680 "fmla v19.4s, v9.4s, v10.s[1]\n"
6681 "fmla v21.4s, v9.4s, v10.s[2]\n"
6682 "fmla v23.4s, v9.4s, v10.s[3]\n"
6683
6684 "fmla v16.4s, v12.4s, v14.s[0]\n"
6685 "fmla v18.4s, v12.4s, v14.s[1]\n"
6686 "fmla v20.4s, v12.4s, v14.s[2]\n"
6687 "fmla v22.4s, v12.4s, v14.s[3]\n"
6688 "fmla v24.4s, v12.4s, v15.s[0]\n"
6689 "fmla v26.4s, v12.4s, v15.s[1]\n"
6690 "fmla v28.4s, v12.4s, v15.s[2]\n"
6691 "fmla v30.4s, v12.4s, v15.s[3]\n"
6692 "fmla v25.4s, v13.4s, v15.s[0]\n"
6693 "fmla v27.4s, v13.4s, v15.s[1]\n"
6694 "fmla v29.4s, v13.4s, v15.s[2]\n"
6695 "fmla v31.4s, v13.4s, v15.s[3]\n"
6696 "fmla v17.4s, v13.4s, v14.s[0]\n"
6697 "fmla v19.4s, v13.4s, v14.s[1]\n"
6698 "fmla v21.4s, v13.4s, v14.s[2]\n"
6699 "fmla v23.4s, v13.4s, v14.s[3]\n"
6700
6701 "78:\n"
6702 #endif
6703
6704 // Accumulation loop
6705 "cmp w1, w12\n"
6706 "beq 79f\n"
6707
6708 "2:\n"
6709 "fmla v24.4s, v0.4s, v3.s[0]\n"
6710 "fmla v26.4s, v0.4s, v3.s[1]\n"
6711 "ld1 {v4.4s}, [%[rhs_ptr]], #16\n"
6712 "fmla v28.4s, v0.4s, v3.s[2]\n"
6713 "fmla v30.4s, v0.4s, v3.s[3]\n"
6714 "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
6715 "fmla v25.4s, v1.4s, v3.s[0]\n"
6716 "fmla v27.4s, v1.4s, v3.s[1]\n"
6717 "add w1, w1, #1\n"
6718 "fmla v29.4s, v1.4s, v3.s[2]\n"
6719 "fmla v31.4s, v1.4s, v3.s[3]\n"
6720 "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
6721 "fmla v17.4s, v1.4s, v2.s[0]\n"
6722 "fmla v19.4s, v1.4s, v2.s[1]\n"
6723 "cmp w1, w12\n"
6724 "fmla v21.4s, v1.4s, v2.s[2]\n"
6725 "fmla v23.4s, v1.4s, v2.s[3]\n"
6726 "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
6727 "fmla v16.4s, v0.4s, v4.s[0]\n"
6728 "fmla v18.4s, v0.4s, v4.s[1]\n"
6729 "mov v2.16b, v4.16b\n"
6730 "fmla v20.4s, v0.4s, v4.s[2]\n"
6731 "fmla v22.4s, v0.4s, v4.s[3]\n"
6732 "blt 2b\n"
6733
6734 "79:\n"
6735
6736 // End of the inner loop on depth. Now perform the remaining
6737 // multiply-adds of the last level of depth, for which the LHS
6738 // and RHS data is already loaded.
6739
6740 "fmla v24.4s, v0.4s, v3.s[0]\n"
6741 "fmla v26.4s, v0.4s, v3.s[1]\n"
6742 "fmla v28.4s, v0.4s, v3.s[2]\n"
6743 "fmla v30.4s, v0.4s, v3.s[3]\n"
6744 "fmla v25.4s, v1.4s, v3.s[0]\n"
6745 "fmla v27.4s, v1.4s, v3.s[1]\n"
6746 "fmla v29.4s, v1.4s, v3.s[2]\n"
6747 "fmla v31.4s, v1.4s, v3.s[3]\n"
6748 "fmla v17.4s, v1.4s, v2.s[0]\n"
6749 "fmla v19.4s, v1.4s, v2.s[1]\n"
6750 "fmla v21.4s, v1.4s, v2.s[2]\n"
6751 "fmla v23.4s, v1.4s, v2.s[3]\n"
6752
        // End of accumulation. The registers v16 -- v31 contain the final
        // float accumulator values of the current 8x8 destination block.
        // We now have to perform the post-accumulation work (bias-addition,
        // clamping, stores) on these accumulators, and advance to the next
        // 8x8 block. We intertwine these two aspects whenever possible for
        // optimal pipelining, both at the data flow level (prefetch data for
        // next block as early as possible) and instruction pipelining level
        // (some of the next-block work can dual-issue with some of the final
        // work on the current block).
6762
6763 // Logic to advance to the next block in preparation for the next
6764 // iteration of the main loop. For now, we only want to compute
6765 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
6766 // not yet ready to update the values of row and col, as we still need
6767 // the current values for the rest of the work on the current block.
6768
6769 "cmp %w[row], w7\n" // Have we finished the last row?
6770 "bge 4f\n" // If finished last row, go to 4
6771 // Not finished last row: then advance to next row.
6772 "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
6773 "b 5f\n"
6774 "4:\n" // Finished last row...
6775 "mov %[lhs_col_ptr], x5\n" // Go back to first row
6776 // Now we need to advance to the next column. If we already
6777 // finished the last column, then in principle we are done, however
6778 // we can't just return here, as we need to allow the end work of the
6779 // current block to complete. The good news is that at this point it
6780 // doesn't matter what data we load for the next column, since
6781 // we will exit from the main loop below before actually storing
6782 // anything computed from that data.
6783 "cmp %w[col], w8\n" // Have we finished the last column?
6784 "bge 5f\n" // If yes, just carry on without updating the column pointer.
6785 // Not finished last column: then advance to next column.
6786 "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
6787 "5:\n"
6788
6789 // Set the LHS and RHS data pointers to the start of the columns just
6790 // computed.
6791 "mov %[lhs_ptr], %[lhs_col_ptr]\n"
6792 "mov %[rhs_ptr], %[rhs_col_ptr]\n"
6793
6794 // Load some parameters needed for the end work on current block.
6795 "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
6796 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
6797
6798 // Determine the channel index.
6799 "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
6800 "csel w3, %w[row], %w[col], eq\n"
6801
6802 // Offset the bias pointer as needed given the current row, col.
6803 "add x5, x1, x3, lsl #2\n"
6804
6805 // If there is no bias, use no offset, just address the passed zero
6806 // data.
6807 "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
6808 "csel x1, x1, x5, eq\n"
6809
6810 // Load 8 bias values.
6811 "ld1 {v14.4s}, [x1], #16\n"
6812 "ld1 {v15.4s}, [x1]\n"
6813
6814 // Now that we know what LHS and RHS data the next iteration of the
6815 // main loop will need to load, we start loading the first 32 bytes of
6816 // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
6817 // in the rest of the work on the current block.
6818 "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
6819 "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
6820 "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
6821 "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
6822
6823 // Perform the bias-addition.
6824 // Jump based on channel dimension.
6825 "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
6826 "bne 6f\n"
6827 // Case where channels are rows
6828 "fadd v16.4s, v16.4s, v14.4s\n"
6829 "fadd v17.4s, v17.4s, v15.4s\n"
6830 "fadd v18.4s, v18.4s, v14.4s\n"
6831 "fadd v19.4s, v19.4s, v15.4s\n"
6832 "fadd v20.4s, v20.4s, v14.4s\n"
6833 "fadd v21.4s, v21.4s, v15.4s\n"
6834 "fadd v22.4s, v22.4s, v14.4s\n"
6835 "fadd v23.4s, v23.4s, v15.4s\n"
6836 "fadd v24.4s, v24.4s, v14.4s\n"
6837 "fadd v25.4s, v25.4s, v15.4s\n"
6838 "fadd v26.4s, v26.4s, v14.4s\n"
6839 "fadd v27.4s, v27.4s, v15.4s\n"
6840 "fadd v28.4s, v28.4s, v14.4s\n"
6841 "fadd v29.4s, v29.4s, v15.4s\n"
6842 "fadd v30.4s, v30.4s, v14.4s\n"
6843 "fadd v31.4s, v31.4s, v15.4s\n"
6844 "b 7f\n"
6845
6846 "6:\n"
6847 // Case where channels are columns
6848 "dup v8.4s, v14.s[0]\n"
6849 "dup v9.4s, v14.s[1]\n"
6850 "dup v10.4s, v14.s[2]\n"
6851 "dup v11.4s, v14.s[3]\n"
6852 "dup v12.4s, v15.s[0]\n"
6853 "dup v13.4s, v15.s[1]\n"
6854 "dup v14.4s, v15.s[2]\n"
6855 "dup v15.4s, v15.s[3]\n"
6856 "fadd v16.4s, v16.4s, v8.4s\n"
6857 "fadd v17.4s, v17.4s, v8.4s\n"
6858 "fadd v18.4s, v18.4s, v9.4s\n"
6859 "fadd v19.4s, v19.4s, v9.4s\n"
6860 "fadd v20.4s, v20.4s, v10.4s\n"
6861 "fadd v21.4s, v21.4s, v10.4s\n"
6862 "fadd v22.4s, v22.4s, v11.4s\n"
6863 "fadd v23.4s, v23.4s, v11.4s\n"
6864 "fadd v24.4s, v24.4s, v12.4s\n"
6865 "fadd v25.4s, v25.4s, v12.4s\n"
6866 "fadd v26.4s, v26.4s, v13.4s\n"
6867 "fadd v27.4s, v27.4s, v13.4s\n"
6868 "fadd v28.4s, v28.4s, v14.4s\n"
6869 "fadd v29.4s, v29.4s, v14.4s\n"
6870 "fadd v30.4s, v30.4s, v15.4s\n"
6871 "fadd v31.4s, v31.4s, v15.4s\n"
6872 "7:\n"
6873
6874 // Load the clamp_min, clamp_max bounds
6875 "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
6876 "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
6877 "dup v14.4s, w2\n" // clamp_min
6878 "dup v15.4s, w3\n" // clamp_max
6879
6880 // Apply the clamp_min bound
6881 "fmax v16.4s, v16.4s, v14.4s\n"
6882 "fmax v17.4s, v17.4s, v14.4s\n"
6883 "fmax v18.4s, v18.4s, v14.4s\n"
6884 "fmax v19.4s, v19.4s, v14.4s\n"
6885 "fmax v20.4s, v20.4s, v14.4s\n"
6886 "fmax v21.4s, v21.4s, v14.4s\n"
6887 "fmax v22.4s, v22.4s, v14.4s\n"
6888 "fmax v23.4s, v23.4s, v14.4s\n"
6889 "fmax v24.4s, v24.4s, v14.4s\n"
6890 "fmax v25.4s, v25.4s, v14.4s\n"
6891 "fmax v26.4s, v26.4s, v14.4s\n"
6892 "fmax v27.4s, v27.4s, v14.4s\n"
6893 "fmax v28.4s, v28.4s, v14.4s\n"
6894 "fmax v29.4s, v29.4s, v14.4s\n"
6895 "fmax v30.4s, v30.4s, v14.4s\n"
6896 "fmax v31.4s, v31.4s, v14.4s\n"
6897
6898 // Apply the clamp_max bound
6899 "fmin v16.4s, v16.4s, v15.4s\n"
6900 "fmin v17.4s, v17.4s, v15.4s\n"
6901 "fmin v18.4s, v18.4s, v15.4s\n"
6902 "fmin v19.4s, v19.4s, v15.4s\n"
6903 "fmin v20.4s, v20.4s, v15.4s\n"
6904 "fmin v21.4s, v21.4s, v15.4s\n"
6905 "fmin v22.4s, v22.4s, v15.4s\n"
6906 "fmin v23.4s, v23.4s, v15.4s\n"
6907 "fmin v24.4s, v24.4s, v15.4s\n"
6908 "fmin v25.4s, v25.4s, v15.4s\n"
6909 "fmin v26.4s, v26.4s, v15.4s\n"
6910 "fmin v27.4s, v27.4s, v15.4s\n"
6911 "fmin v28.4s, v28.4s, v15.4s\n"
6912 "fmin v29.4s, v29.4s, v15.4s\n"
6913 "fmin v30.4s, v30.4s, v15.4s\n"
6914 "fmin v31.4s, v31.4s, v15.4s\n"
6915
        // Compute how much of the 8x8 block of destination float values that
        // we have computed fits in the destination matrix. Typically, all of
        // it fits, but when the destination matrix shape is not a multiple
        // of 8x8, there are some 8x8 blocks along the boundaries that do
        // not fit entirely.
6921 "sub w1, %w[dst_rows], %w[row]\n"
6922 "sub w2, %w[dst_cols], %w[col]\n"
6923 "mov w3, #8\n"
6924 "cmp w1, #8\n"
6925 // Compute w1 = how many rows of the 8x8 block fit
6926 "csel w1, w1, w3, le\n"
6927 "cmp w2, #8\n"
6928 // Compute w2 = how many cols of the 8x8 block fit
6929 "csel w2, w2, w3, le\n"
6930
        // Test if w1 == 8 && w2 == 8, i.e. if all of the 8x8 block fits.
6932 "cmp w1, w3\n"
6933 "ccmp w2, w3, 0, eq\n"
6934 // Yes, all of the 8x8 block fits, go to fast path.
6935 "beq 30f\n"
6936 // Not all of the 8x8 block fits.
6937 // Set (x3 address, x4 stride) to write to dst_tmp_buf
6938 "mov x3, %[dst_tmp_buf]\n"
6939 "mov x4, #32\n"
6940 "b 31f\n"
6941 "30:\n"
6942 // Yes, all of the 8x8 block fits.
6943 // Set (x3 address, x4 stride) to write directly to destination matrix.
6944 "mov x3, %[dst_ptr]\n"
6945 "mov x4, x11\n"
6946 "31:\n"
6947
        // Write our float values to the destination described by
        // (x3 address, x4 stride).
6950 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6951 "str q16, [x3, #0]\n"
6952 "str q17, [x3, #16]\n"
6953 "add x3, x3, x4\n"
6954 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6955 RUY_MAKE_ZERO(v16)
6956 RUY_MAKE_ZERO(v17)
6957 "str q18, [x3, #0]\n"
6958 "str q19, [x3, #16]\n"
6959 "add x3, x3, x4\n"
6960 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6961 RUY_MAKE_ZERO(v18)
6962 RUY_MAKE_ZERO(v19)
6963 "str q20, [x3, #0]\n"
6964 "str q21, [x3, #16]\n"
6965 "add x3, x3, x4\n"
6966 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6967 RUY_MAKE_ZERO(v20)
6968 RUY_MAKE_ZERO(v21)
6969 "str q22, [x3, #0]\n"
6970 "str q23, [x3, #16]\n"
6971 "add x3, x3, x4\n"
6972 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6973 RUY_MAKE_ZERO(v22)
6974 RUY_MAKE_ZERO(v23)
6975 "str q24, [x3, #0]\n"
6976 "str q25, [x3, #16]\n"
6977 "add x3, x3, x4\n"
6978 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6979 RUY_MAKE_ZERO(v24)
6980 RUY_MAKE_ZERO(v25)
6981 "str q26, [x3, #0]\n"
6982 "str q27, [x3, #16]\n"
6983 "add x3, x3, x4\n"
6984 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6985 RUY_MAKE_ZERO(v26)
6986 RUY_MAKE_ZERO(v27)
6987 "str q28, [x3, #0]\n"
6988 "str q29, [x3, #16]\n"
6989 "add x3, x3, x4\n"
6990 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
6991 RUY_MAKE_ZERO(v28)
6992 RUY_MAKE_ZERO(v29)
6993 "str q30, [x3, #0]\n"
6994 "str q31, [x3, #16]\n"
6995 RUY_MAKE_ZERO(v30)
6996 RUY_MAKE_ZERO(v31)
6997
6998 // If all of the 8x8 block fits, we just finished writing it to the
6999 // destination, so we skip the next part.
7000 "beq 41f\n"
7001 // Not all of the 8x8 block fits in the destination matrix. We just
7002 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
7003 // it to copy into the destination matrix the part that fits.
7004 "mov x3, %[dst_tmp_buf]\n"
7005 "mov x4, %[dst_ptr]\n"
7006 "mov w6, #0\n"
7007 "50:\n"
7008 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
7009 "mov w5, #0\n"
7010 "51:\n"
7011 "ldr w7, [x3, x5, lsl #2]\n"
7012 "str w7, [x4, x5, lsl #2]\n"
7013 "add w5, w5, #1\n"
7014 "cmp w5, w1\n"
7015 "blt 51b\n"
7016 "add w6, w6, #1\n"
7017 "add x3, x3, #32\n"
7018 "add x4, x4, x11\n"
7019 "cmp w6, w2\n"
7020 "blt 50b\n"
7021 "41:\n"
7022 "add %[dst_ptr], %[dst_ptr], #32\n"
7023 // At this point we have completely finished writing values to the
7024 // destination matrix for the current block.
7025
7026 // Reload some params --- we had used x5 -- x7 for a few other things
7027 // since the last time we had loaded them.
7028 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
7029 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
7030 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
7031
7032 // Move to the next block of the destination matrix, for the next iter
7033 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
7034 // been updated earlier.
7035 // Have we reached the end row?
7036 "cmp %w[row], w7\n"
7037 "beq 20f\n" // yes, end row.
7038 // Not end row. Move to the next row.
7039 "add %w[row], %w[row], #8\n"
7040 "b 21f\n"
7041 "20:\n"
7042 // Was already at end row.
7043 "mov %w[row], w6\n" // Move back to first row.
7044 "add %w[col], %w[col], #8\n" // Move to the next column.
7045 "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
7046 "mov %[dst_ptr], %[dst_col_ptr]\n"
7047 "21:\n"
7048
7049 // Main loop exit condition: have we hit the end column?
7050 "cmp %w[col], w8\n"
7051
7052 // w1 is the number of levels of depth that we have already loaded
7053 // LHS and RHS data for. Corresponding to the initial ld1 instructions
7054 // above, this is currently 1.
7055 "mov w1, #1\n"
7056
7057 "ble 1b\n"
7058
7059 // clang-format on
7060
        : [lhs_col_ptr] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
7062 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
7063 [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
        : [params] "r"(&params), [dst_rows] "r"(params.dst_rows),
7065 [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf)
7066 : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
7067 "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
7068 "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
7069 "v26", "v27", "v28", "v29", "v30", "v31");
7070 }
7071
7072 // Variant of KernelFloatNeon tuned for in-order CPUs that do not
7073 // support dotprod (while dotprod by itself is not relevant to floating-point,
7074 // this additional bit of information that we have about the target happens to
7075 // be useful here).
7076 //
7077 // So a typical target CPU here would be ARM Cortex-A53 or the original
7078 // Cortex-A55.
7079 //
// This kernel is similar to and inspired by gemmlowp's
// NEON_64bit_GEMM_Float32_WithScalar_A53,
// which was contributed by David Mansell with very helpful
// comments. Specifically, see this comment about tuning for Cortex-A53:
// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215
void KernelFloatNeonA55ish(const KernelParamsFloat<8, 8>& params) {
7086 profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)");
7087
7088 CheckOffsetsInKernelParamsFloat(params);
7089
7090 const float* lhs_col_ptr = params.lhs_base_ptr;
7091 const float* rhs_col_ptr = params.rhs_base_ptr;
7092 const float* lhs_ptr = lhs_col_ptr;
7093 const float* rhs_ptr = rhs_col_ptr;
7094 float* dst_col_ptr = params.dst_base_ptr;
7095 float* dst_ptr = dst_col_ptr;
7096 int row = params.start_row;
7097 int col = params.start_col;
7098
7099 // The asm kernel below has the following NEON register allocation:
7100 //
7101 // v16 -- v31 are accumulators.
7102 // During accumulation, v0 -- v3 are used to load data from LHS and RHS.
7103 //
  //                                          RHS 1x8 block
  //                           /-----------------------------------------\
  //                           |v2.s[0] ... v2.s[3]   v3.s[0] ... v3.s[3]|
  //                           \-----------------------------------------/
  //        LHS 8x1 block
  //  /---------------------\  /-----------------------------------------\
  //  |        v0.s[0]      |  |v16.s[0]           ...           v30.s[0]|
  //  |         ...         |  |  ...                              ...   |
  //  |        v0.s[3]      |  |v16.s[3]           ...           v30.s[3]|
  //  |        v1.s[0]      |  |v17.s[0]           ...           v31.s[0]|
  //  |         ...         |  |  ...                              ...   |
  //  |        v1.s[3]      |  |v17.s[3]           ...           v31.s[3]|
  //  \---------------------/  \-----------------------------------------/
  //                                      accumulators 8x8 block
7118 //
7119 // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because
7120 // we did not observe a benefit of such partial unrolling on in-order CPUs.
7121 //
  // v4 -- v7 are unused, and v8 -- v15 are used for loading parameters used
  // for the post-accumulation part of the kernel.
7124 asm volatile(
7125 #define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
7126
7127 // clang-format off
7128
7129 // Load some parameters into registers.
7130 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
7131 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
7132 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
7133 "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
7134 "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
7135 "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
7136 "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
7137 "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
7138
7139
7140 // Clear accumulators.
7141 RUY_MAKE_ZERO(v16)
7142 // Load the first 32 bytes of LHS and RHS data.
7143 "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
7144 RUY_MAKE_ZERO(v17)
7145 "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
7146 RUY_MAKE_ZERO(v18)
7147 "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
7148 RUY_MAKE_ZERO(v19)
7149 "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
7150 RUY_MAKE_ZERO(v20)
7151 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n")
7152 RUY_MAKE_ZERO(v21)
7153 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n")
7154 RUY_MAKE_ZERO(v22)
7155 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n")
7156 RUY_MAKE_ZERO(v23)
7157 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n")
7158 RUY_MAKE_ZERO(v24)
7159 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n")
7160 RUY_MAKE_ZERO(v25)
7161 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n")
7162 RUY_MAKE_ZERO(v26)
7163 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
7164 RUY_MAKE_ZERO(v27)
7165 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
7166 RUY_MAKE_ZERO(v28)
7167 RUY_MAKE_ZERO(v29)
7168 RUY_MAKE_ZERO(v30)
7169 RUY_MAKE_ZERO(v31)
7170
7171 // w1 is the number of levels of depth that remain to load
7172 // LHS and RHS data for. Corresponding to the initial ld1 instructions
7173 // above, this is currently depth - 1.
7174 "sub w1, w12, #1\n"
7175
7176 // Main loop of the whole GEMM, over rows and columns of the
7177 // destination matrix.
7178 "1:\n"
7179
7180 "cmp w1, #0\n"
7181 "fmla v16.4s, v0.4s, v2.s[0]\n"
7182 "fmla v18.4s, v0.4s, v2.s[1]\n"
7183 "fmla v20.4s, v0.4s, v2.s[2]\n"
7184 "fmla v22.4s, v0.4s, v2.s[3]\n"
7185
7186 // Accumulation loop
7187 "beq 79f\n"
7188
7189 "2:\n"
7190
7191 "fmla v24.4s, v0.4s, v3.s[0]\n"
7192 "ldr x2, [%[lhs_ptr], #8]\n"
7193 "fmla v26.4s, v0.4s, v3.s[1]\n"
7194 "ldr x3, [%[lhs_ptr], #24]\n"
7195 "fmla v28.4s, v0.4s, v3.s[2]\n"
7196 "ldr x5, [%[rhs_ptr], #24]\n"
7197 "fmla v30.4s, v0.4s, v3.s[3]\n"
7198 "ldr x4, [%[rhs_ptr], #8]\n"
7199 "fmla v25.4s, v1.4s, v3.s[0]\n"
7200 "subs w1, w1, #1\n"
7201 "ldr d0, [%[lhs_ptr]], #32\n"
7202 "fmla v27.4s, v1.4s, v3.s[1]\n"
7203 "fmla v29.4s, v1.4s, v3.s[2]\n"
7204 "fmla v31.4s, v1.4s, v3.s[3]\n"
7205 "ins v0.d[1], x2\n"
7206 "ldr d3, [%[rhs_ptr], #16]\n"
7207 "fmla v17.4s, v1.4s, v2.s[0]\n"
7208 "fmla v19.4s, v1.4s, v2.s[1]\n"
7209 "ins v3.d[1], x5\n"
7210 "ldr d4, [%[rhs_ptr]], #32\n"
7211 "fmla v21.4s, v1.4s, v2.s[2]\n"
7212 "fmla v23.4s, v1.4s, v2.s[3]\n"
7213 "fmla v16.4s, v0.4s, v4.s[0]\n"
7214 "ins v4.d[1], x4\n"
7215 "ldr d1, [%[lhs_ptr], #-16]\n"
7216 "fmla v18.4s, v0.4s, v4.s[1]\n"
7217 "fmla v20.4s, v0.4s, v4.s[2]\n"
7218 "ins v1.d[1], x3\n"
7219 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
7220 "mov v2.16b, v4.16b\n"
7221 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
7222 "fmla v22.4s, v0.4s, v4.s[3]\n"
7223 "bne 2b\n"
7224
7225 "79:\n"
7226
7227 // End of the inner loop on depth. Now perform the remaining
7228 // multiply-adds of the last level of depth, for which the LHS
7229 // and RHS data is already loaded.
7230
7231 "fmla v24.4s, v0.4s, v3.s[0]\n"
7232 "fmla v26.4s, v0.4s, v3.s[1]\n"
7233 "fmla v28.4s, v0.4s, v3.s[2]\n"
7234 "fmla v30.4s, v0.4s, v3.s[3]\n"
7235 "fmla v25.4s, v1.4s, v3.s[0]\n"
7236 "fmla v27.4s, v1.4s, v3.s[1]\n"
7237 "fmla v29.4s, v1.4s, v3.s[2]\n"
7238 "fmla v31.4s, v1.4s, v3.s[3]\n"
7239 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
7240 "fmla v17.4s, v1.4s, v2.s[0]\n"
7241 "fmla v19.4s, v1.4s, v2.s[1]\n"
7242 "fmla v21.4s, v1.4s, v2.s[2]\n"
7243 "fmla v23.4s, v1.4s, v2.s[3]\n"
7244
        // End of accumulation. The registers v16 -- v31 contain the final
        // float accumulator values of the current 8x8 destination block.
        // We now have to perform the post-accumulation work (bias-addition,
        // clamping, stores) on these accumulators, and advance to the next
        // 8x8 block. We intertwine these two aspects whenever possible for
        // optimal pipelining, both at the data flow level (prefetch data for
        // next block as early as possible) and instruction pipelining level
        // (some of the next-block work can dual-issue with some of the final
        // work on the current block).
7254
7255 // Logic to advance to the next block in preparation for the next
7256 // iteration of the main loop. For now, we only want to compute
7257 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
7258 // not yet ready to update the values of row and col, as we still need
7259 // the current values for the rest of the work on the current block.
7260
7261 "cmp %w[row], w7\n" // Have we finished the last row?
7262 "bge 4f\n" // If finished last row, go to 4
7263 // Not finished last row: then advance to next row.
7264 "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
7265 "b 5f\n"
7266 "4:\n" // Finished last row...
7267 "mov %[lhs_col_ptr], x5\n" // Go back to first row
7268 // Now we need to advance to the next column. If we already
7269 // finished the last column, then in principle we are done, however
7270 // we can't just return here, as we need to allow the end work of the
7271 // current block to complete. The good news is that at this point it
7272 // doesn't matter what data we load for the next column, since
7273 // we will exit from the main loop below before actually storing
7274 // anything computed from that data.
7275 "cmp %w[col], w8\n" // Have we finished the last column?
7276 "bge 5f\n" // If yes, just carry on without updating the column pointer.
7277 // Not finished last column: then advance to next column.
7278 "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
7279 "5:\n"
7280
7281 // Set the LHS and RHS data pointers to the start of the columns just
7282 // computed.
7283 "mov %[lhs_ptr], %[lhs_col_ptr]\n"
7284 "mov %[rhs_ptr], %[rhs_col_ptr]\n"
7285
7286 // Load some parameters needed for the end work on current block.
7287 "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
7288 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
7289
7290 // Determine the channel index.
7291 "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
7292 "csel w3, %w[row], %w[col], eq\n"
7293
7294 // Offset the bias pointer as needed given the current row, col.
7295 "add x5, x1, x3, lsl #2\n"
7296
7297 // If there is no bias, use no offset, just address the passed zero
7298 // data.
7299
7300 "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
7301 "csel x1, x1, x5, eq\n"
7302
7303 // Load 8 bias values.
7304 "ld1 {v14.4s}, [x1], #16\n"
7305 "ld1 {v15.4s}, [x1]\n"
7306
7307 // Now that we know what LHS and RHS data the next iteration of the
7308 // main loop will need to load, we start loading the first 32 bytes of
7309 // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
7310 // in the rest of the work on the current block.
7311 "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
7312 "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
7313 "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
7314 "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
7315
7316 // Perform the bias-addition.
7317 // Jump based on channel dimension.
7318 "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
7319 "bne 6f\n"
7320 // Case where channels are rows
7321 "fadd v16.4s, v16.4s, v14.4s\n"
7322 "fadd v17.4s, v17.4s, v15.4s\n"
7323 "fadd v18.4s, v18.4s, v14.4s\n"
7324 "fadd v19.4s, v19.4s, v15.4s\n"
7325 "fadd v20.4s, v20.4s, v14.4s\n"
7326 "fadd v21.4s, v21.4s, v15.4s\n"
7327 "fadd v22.4s, v22.4s, v14.4s\n"
7328 "fadd v23.4s, v23.4s, v15.4s\n"
7329 "fadd v24.4s, v24.4s, v14.4s\n"
7330 "fadd v25.4s, v25.4s, v15.4s\n"
7331 "fadd v26.4s, v26.4s, v14.4s\n"
7332 "fadd v27.4s, v27.4s, v15.4s\n"
7333 "fadd v28.4s, v28.4s, v14.4s\n"
7334 "fadd v29.4s, v29.4s, v15.4s\n"
7335 "fadd v30.4s, v30.4s, v14.4s\n"
7336 "fadd v31.4s, v31.4s, v15.4s\n"
7337 "b 7f\n"
7338
7339 "6:\n"
7340 // Case where channels are columns
7341 "dup v8.4s, v14.s[0]\n"
7342 "dup v9.4s, v14.s[1]\n"
7343 "fadd v16.4s, v16.4s, v8.4s\n"
7344 "dup v10.4s, v14.s[2]\n"
7345 "fadd v17.4s, v17.4s, v8.4s\n"
7346 "dup v11.4s, v14.s[3]\n"
7347 "fadd v18.4s, v18.4s, v9.4s\n"
7348 "dup v12.4s, v15.s[0]\n"
7349 "fadd v19.4s, v19.4s, v9.4s\n"
7350 "dup v13.4s, v15.s[1]\n"
7351 "fadd v20.4s, v20.4s, v10.4s\n"
7352 "dup v14.4s, v15.s[2]\n"
7353 "fadd v21.4s, v21.4s, v10.4s\n"
7354 "dup v15.4s, v15.s[3]\n"
7355 "fadd v22.4s, v22.4s, v11.4s\n"
7356 "fadd v23.4s, v23.4s, v11.4s\n"
7357 "fadd v24.4s, v24.4s, v12.4s\n"
7358 "fadd v25.4s, v25.4s, v12.4s\n"
7359 "fadd v26.4s, v26.4s, v13.4s\n"
7360 "fadd v27.4s, v27.4s, v13.4s\n"
7361 "fadd v28.4s, v28.4s, v14.4s\n"
7362 "fadd v29.4s, v29.4s, v14.4s\n"
7363 "fadd v30.4s, v30.4s, v15.4s\n"
7364 "fadd v31.4s, v31.4s, v15.4s\n"
7365 "7:\n"
7366
7367 // Load the clamp_min, clamp_max bounds
7368 "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
7369 "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
7370 "dup v14.4s, w2\n" // clamp_min
7371 "dup v15.4s, w3\n" // clamp_max
7372
7373 // Apply the clamp_min bound
7374 "fmax v16.4s, v16.4s, v14.4s\n"
7375 "fmax v17.4s, v17.4s, v14.4s\n"
7376 "fmax v18.4s, v18.4s, v14.4s\n"
7377 "fmax v19.4s, v19.4s, v14.4s\n"
7378 "fmax v20.4s, v20.4s, v14.4s\n"
7379 "fmax v21.4s, v21.4s, v14.4s\n"
7380 "fmax v22.4s, v22.4s, v14.4s\n"
7381 "fmax v23.4s, v23.4s, v14.4s\n"
7382 "fmax v24.4s, v24.4s, v14.4s\n"
7383 "fmax v25.4s, v25.4s, v14.4s\n"
7384 "fmax v26.4s, v26.4s, v14.4s\n"
7385 "fmax v27.4s, v27.4s, v14.4s\n"
7386 "fmax v28.4s, v28.4s, v14.4s\n"
7387 "fmax v29.4s, v29.4s, v14.4s\n"
7388 "fmax v30.4s, v30.4s, v14.4s\n"
7389 "fmax v31.4s, v31.4s, v14.4s\n"
7390
7391 // Apply the clamp_max bound
7392 "fmin v16.4s, v16.4s, v15.4s\n"
7393 "fmin v17.4s, v17.4s, v15.4s\n"
7394 "fmin v18.4s, v18.4s, v15.4s\n"
7395 "fmin v19.4s, v19.4s, v15.4s\n"
7396 "fmin v20.4s, v20.4s, v15.4s\n"
7397 "fmin v21.4s, v21.4s, v15.4s\n"
7398 "fmin v22.4s, v22.4s, v15.4s\n"
7399 "fmin v23.4s, v23.4s, v15.4s\n"
7400 "fmin v24.4s, v24.4s, v15.4s\n"
7401 "fmin v25.4s, v25.4s, v15.4s\n"
7402 "fmin v26.4s, v26.4s, v15.4s\n"
7403 "fmin v27.4s, v27.4s, v15.4s\n"
7404 "fmin v28.4s, v28.4s, v15.4s\n"
7405 "fmin v29.4s, v29.4s, v15.4s\n"
7406 "fmin v30.4s, v30.4s, v15.4s\n"
7407 "fmin v31.4s, v31.4s, v15.4s\n"
7408
        // Compute how much of the 8x8 block of destination float values that
        // we have computed fits in the destination matrix. Typically, all of
        // it fits, but when the destination matrix shape is not a multiple
        // of 8x8, there are some 8x8 blocks along the boundaries that do
        // not fit entirely.
7414 "sub w1, %w[dst_rows], %w[row]\n"
7415 "sub w2, %w[dst_cols], %w[col]\n"
7416 "mov w3, #8\n"
7417 "cmp w1, #8\n"
7418 // Compute w1 = how many rows of the 8x8 block fit
7419 "csel w1, w1, w3, le\n"
7420 "cmp w2, #8\n"
7421 // Compute w2 = how many cols of the 8x8 block fit
7422 "csel w2, w2, w3, le\n"
7423
        // Test if w1 == 8 && w2 == 8, i.e. if all of the 8x8 block fits.
7425 "cmp w1, w3\n"
7426 "ccmp w2, w3, 0, eq\n"
7427 // Yes, all of the 8x8 block fits, go to fast path.
7428 "beq 30f\n"
7429 // Not all of the 8x8 block fits.
7430 // Set (x3 address, x4 stride) to write to dst_tmp_buf
7431 "mov x3, %[dst_tmp_buf]\n"
7432 "mov x4, #32\n"
7433 "b 31f\n"
7434 "30:\n"
7435 // Yes, all of the 8x8 block fits.
7436 // Set (x3 address, x4 stride) to write directly to destination matrix.
7437 "mov x3, %[dst_ptr]\n"
7438 "mov x4, x11\n"
7439 "31:\n"
7440
        // Write our float values to the destination described by
        // (x3 address, x4 stride).
7443 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7444 "str q16, [x3, #0]\n"
7445 "str q17, [x3, #16]\n"
7446 "add x3, x3, x4\n"
7447 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7448 RUY_MAKE_ZERO(v16)
7449 RUY_MAKE_ZERO(v17)
7450 "str q18, [x3, #0]\n"
7451 "str q19, [x3, #16]\n"
7452 "add x3, x3, x4\n"
7453 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7454 RUY_MAKE_ZERO(v18)
7455 RUY_MAKE_ZERO(v19)
7456 "str q20, [x3, #0]\n"
7457 "str q21, [x3, #16]\n"
7458 "add x3, x3, x4\n"
7459 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7460 RUY_MAKE_ZERO(v20)
7461 RUY_MAKE_ZERO(v21)
7462 "str q22, [x3, #0]\n"
7463 "str q23, [x3, #16]\n"
7464 "add x3, x3, x4\n"
7465 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7466 RUY_MAKE_ZERO(v22)
7467 RUY_MAKE_ZERO(v23)
7468 "str q24, [x3, #0]\n"
7469 "str q25, [x3, #16]\n"
7470 "add x3, x3, x4\n"
7471 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7472 RUY_MAKE_ZERO(v24)
7473 RUY_MAKE_ZERO(v25)
7474 "str q26, [x3, #0]\n"
7475 "str q27, [x3, #16]\n"
7476 "add x3, x3, x4\n"
7477 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7478 RUY_MAKE_ZERO(v26)
7479 RUY_MAKE_ZERO(v27)
7480 "str q28, [x3, #0]\n"
7481 "str q29, [x3, #16]\n"
7482 "add x3, x3, x4\n"
7483 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7484 RUY_MAKE_ZERO(v28)
7485 RUY_MAKE_ZERO(v29)
7486 "str q30, [x3, #0]\n"
7487 "str q31, [x3, #16]\n"
7488 RUY_MAKE_ZERO(v30)
7489 RUY_MAKE_ZERO(v31)
7490
7491 // If all of the 8x8 block fits, we just finished writing it to the
7492 // destination, so we skip the next part.
7493 "beq 41f\n"
7494 // Not all of the 8x8 block fits in the destination matrix. We just
7495 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
7496 // it to copy into the destination matrix the part that fits.
7497 "mov x3, %[dst_tmp_buf]\n"
7498 "mov x4, %[dst_ptr]\n"
7499 "mov w6, #0\n"
7500 "50:\n"
7501 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
7502 "mov w5, #0\n"
7503 "51:\n"
7504 "ldr w7, [x3, x5, lsl #2]\n"
7505 "str w7, [x4, x5, lsl #2]\n"
7506 "add w5, w5, #1\n"
7507 "cmp w5, w1\n"
7508 "blt 51b\n"
7509 "add w6, w6, #1\n"
7510 "add x3, x3, #32\n"
7511 "add x4, x4, x11\n"
7512 "cmp w6, w2\n"
7513 "blt 50b\n"
7514 "41:\n"
7515 "add %[dst_ptr], %[dst_ptr], #32\n"
7516 // At this point we have completely finished writing values to the
7517 // destination matrix for the current block.
7518
7519 // Reload some params --- we had used x5 -- x7 for a few other things
7520 // since the last time we had loaded them.
7521 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
7522 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
7523 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
7524
7525 // Move to the next block of the destination matrix, for the next iter
7526 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
7527 // been updated earlier.
7528 // Have we reached the end row?
7529 "cmp %w[row], w7\n"
7530 "beq 20f\n" // yes, end row.
7531 // Not end row. Move to the next row.
7532 "add %w[row], %w[row], #8\n"
7533 "b 21f\n"
7534 "20:\n"
7535 // Was already at end row.
7536 "mov %w[row], w6\n" // Move back to first row.
7537 "add %w[col], %w[col], #8\n" // Move to the next column.
7538 "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
7539 "mov %[dst_ptr], %[dst_col_ptr]\n"
7540 "21:\n"
7541
7542 // Main loop exit condition: have we hit the end column?
7543 "cmp %w[col], w8\n"
7544
7545 // w1 is the number of levels of depth that remain to load
7546 // LHS and RHS data for. Corresponding to the initial ld1 instructions
7547 // above, this is currently depth - 1.
7548 "sub w1, w12, #1\n"
7549
7550 "ble 1b\n"
7551
7552 // clang-format on
7553
        : [lhs_col_ptr] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
7555 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
7556 [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
        : [params] "r"(&params), [dst_rows] "r"(params.dst_rows),
7558 [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf)
7559 : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
7560 "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
7561 "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
7562 "v26", "v27", "v28", "v29", "v30", "v31");
7563 }
7564
7565 // Variant of KernelFloatNeonA55ish tuned for in-order CPUs that do
7566 // support dotprod (while dotprod by itself is not relevant to floating-point,
7567 // this additional bit of information that we have about the target happens to
7568 // be useful here).
7569 //
7570 // So a typical target CPU here would be ARM Cortex-A55r1.
7571 //
// This kernel is similar to and inspired by gemmlowp's
// NEON_64bit_GEMM_Float32_WithScalar_A55r1,
// which was contributed by David Mansell with very helpful
// comments. Specifically, see this comment about tuning for Cortex-A55r1:
// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412
void KernelFloatNeonDotprodA55ish(const KernelParamsFloat<8, 8>& params) {
7578 profiler::ScopeLabel label(
7579 "Kernel (kNeonDotprod, optimized for in-order cores)");
7580
7581 CheckOffsetsInKernelParamsFloat(params);
7582
7583 const float* lhs_col_ptr = params.lhs_base_ptr;
7584 const float* rhs_col_ptr = params.rhs_base_ptr;
7585 const float* lhs_ptr = lhs_col_ptr;
7586 const float* rhs_ptr = rhs_col_ptr;
7587 float* dst_col_ptr = params.dst_base_ptr;
7588 float* dst_ptr = dst_col_ptr;
7589 int row = params.start_row;
7590 int col = params.start_col;
7591
7592 // The asm kernel below has the following NEON register allocation:
7593 //
7594 // v16 -- v31 are accumulators.
7595 // During accumulation, v0 -- v3 are used to load data from LHS and RHS.
7596 //
  //                                          RHS 1x8 block
  //                           /-----------------------------------------\
  //                           |v2.s[0] ... v2.s[3]   v3.s[0] ... v3.s[3]|
  //                           \-----------------------------------------/
  //        LHS 8x1 block
  //  /---------------------\  /-----------------------------------------\
  //  |        v0.s[0]      |  |v16.s[0]           ...           v30.s[0]|
  //  |         ...         |  |  ...                              ...   |
  //  |        v0.s[3]      |  |v16.s[3]           ...           v30.s[3]|
  //  |        v1.s[0]      |  |v17.s[0]           ...           v31.s[0]|
  //  |         ...         |  |  ...                              ...   |
  //  |        v1.s[3]      |  |v17.s[3]           ...           v31.s[3]|
  //  \---------------------/  \-----------------------------------------/
  //                                      accumulators 8x8 block
7611 //
7612 // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because
7613 // we did not observe a benefit of such partial unrolling on in-order CPUs.
7614 //
  // v4 -- v7 are unused, and v8 -- v15 are used for loading parameters used
  // for the post-accumulation part of the kernel.
7617 asm volatile(
7618 #define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n"
7619
7620 // clang-format off
7621
7622 // Load some parameters into registers.
7623 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
7624 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
7625 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
7626 "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
7627 "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
7628 "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
7629 "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
7630 "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
7631
7632
7633 // Clear accumulators.
7634 RUY_MAKE_ZERO(v16)
7635 // Load the first 32 bytes of LHS and RHS data.
7636 "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
7637 RUY_MAKE_ZERO(v17)
7638 "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
7639 RUY_MAKE_ZERO(v18)
7640 "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
7641 RUY_MAKE_ZERO(v19)
7642 "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
7643 RUY_MAKE_ZERO(v20)
7644 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n")
7645 RUY_MAKE_ZERO(v21)
7646 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n")
7647 RUY_MAKE_ZERO(v22)
7648 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n")
7649 RUY_MAKE_ZERO(v23)
7650 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n")
7651 RUY_MAKE_ZERO(v24)
7652 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n")
7653 RUY_MAKE_ZERO(v25)
7654 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n")
7655 RUY_MAKE_ZERO(v26)
7656 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
7657 RUY_MAKE_ZERO(v27)
7658 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
7659 RUY_MAKE_ZERO(v28)
7660 RUY_MAKE_ZERO(v29)
7661 RUY_MAKE_ZERO(v30)
7662 RUY_MAKE_ZERO(v31)
7663
7664 // w1 is the number of levels of depth that remain to load
7665 // LHS and RHS data for. Corresponding to the initial ld1 instructions
7666 // above, this is currently depth - 1.
7667 "sub w1, w12, #1\n"
7668
7669 // Main loop of the whole GEMM, over rows and columns of the
7670 // destination matrix.
7671 "1:\n"
7672
7673 "cmp w1, #0\n"
7674 "fmla v16.4s, v0.4s, v2.s[0]\n"
7675 "fmla v18.4s, v0.4s, v2.s[1]\n"
7676 "fmla v20.4s, v0.4s, v2.s[2]\n"
7677 "fmla v22.4s, v0.4s, v2.s[3]\n"
7678
7679        // Accumulation loop. Skipped when w1 == 0, i.e. when depth == 1.
7680        "beq 79f\n"
7681
7682 "2:\n"
7683
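        // Cortex-A55 load idiom: instead of a single 128-bit vector load,
        // each q-register is filled with a 64-bit "ldr d" into the low half
        // plus a 64-bit GPR "ldr x" followed by an "ins" into the high half.
        // On this in-order core the narrower loads can dual-issue with the
        // fmla instructions, which a full 128-bit "ld1" would not (see the
        // gemmlowp comment linked above).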
7684 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
7685 "fmla v24.4s, v0.4s, v3.s[0]\n"
7686 "ldr x2, [%[lhs_ptr], #8]\n"
7687 "fmla v26.4s, v0.4s, v3.s[1]\n"
7688 "ldr x3, [%[lhs_ptr], #24]\n"
7689 "fmla v28.4s, v0.4s, v3.s[2]\n"
7690 "ldr x5, [%[rhs_ptr], #24]\n"
7691 "fmla v30.4s, v0.4s, v3.s[3]\n"
7692 "ldr d0, [%[lhs_ptr]], #32\n"
7693 "fmla v25.4s, v1.4s, v3.s[0]\n"
7694 "ldr x4, [%[rhs_ptr], #8]\n"
7695 "fmla v27.4s, v1.4s, v3.s[1]\n"
7696 "subs w1, w1, #1\n"
7697 "fmla v29.4s, v1.4s, v3.s[2]\n"
7698 "ins v0.d[1], x2\n"
7699 "fmla v31.4s, v1.4s, v3.s[3]\n"
7700 "ldr d3, [%[rhs_ptr], #16]\n"
7701 "fmla v17.4s, v1.4s, v2.s[0]\n"
7702 "ins v3.d[1], x5\n"
7703 "fmla v19.4s, v1.4s, v2.s[1]\n"
7704 "ldr d4, [%[rhs_ptr]], #32\n"
7705 "fmla v21.4s, v1.4s, v2.s[2]\n"
7706 "ins v4.d[1], x4\n"
7707 "fmla v23.4s, v1.4s, v2.s[3]\n"
7708 RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
7709 "fmla v16.4s, v0.4s, v4.s[0]\n"
7710 "ldr d1, [%[lhs_ptr], #-16]\n"
7711 "fmla v18.4s, v0.4s, v4.s[1]\n"
7712 "ins v1.d[1], x3\n"
7713 "fmla v20.4s, v0.4s, v4.s[2]\n"
7714 "mov v2.16b, v4.16b\n"
7715 "fmla v22.4s, v0.4s, v4.s[3]\n"
7716 "bne 2b\n"
7717
7718 "79:\n"
7719
7720 // End of the inner loop on depth. Now perform the remaining
7721 // multiply-adds of the last level of depth, for which the LHS
7722 // and RHS data is already loaded.
7723
7724 "fmla v24.4s, v0.4s, v3.s[0]\n"
7725 "fmla v26.4s, v0.4s, v3.s[1]\n"
7726 "fmla v28.4s, v0.4s, v3.s[2]\n"
7727 "fmla v30.4s, v0.4s, v3.s[3]\n"
7728 "fmla v25.4s, v1.4s, v3.s[0]\n"
7729 "fmla v27.4s, v1.4s, v3.s[1]\n"
7730 "fmla v29.4s, v1.4s, v3.s[2]\n"
7731 "fmla v31.4s, v1.4s, v3.s[3]\n"
7732 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
7733 "fmla v17.4s, v1.4s, v2.s[0]\n"
7734 "fmla v19.4s, v1.4s, v2.s[1]\n"
7735 "fmla v21.4s, v1.4s, v2.s[2]\n"
7736 "fmla v23.4s, v1.4s, v2.s[3]\n"
7737
7738        // End of accumulation. The registers v16 -- v31 contain the final
7739        // float32 accumulator values of the current 8x8 destination block.
7740        // We now have to perform the end work on these accumulators (bias,
7741        // clamping, stores) and advance to the next 8x8 block. We intertwine
7742        // these two aspects whenever possible for optimal pipelining, both
7743        // at the data flow level (prefetch data for next block as early as
7744        // possible) and instruction pipelining level (some of the next-block
7745        // work can dual-issue with some of the final work on the current
7746        // block).
7747
7748 // Logic to advance to the next block in preparation for the next
7749 // iteration of the main loop. For now, we only want to compute
7750 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
7751 // not yet ready to update the values of row and col, as we still need
7752 // the current values for the rest of the work on the current block.
7753
7754 "cmp %w[row], w7\n" // Have we finished the last row?
7755 "bge 4f\n" // If finished last row, go to 4
7756 // Not finished last row: then advance to next row.
7757 "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
7758 "b 5f\n"
7759 "4:\n" // Finished last row...
7760 "mov %[lhs_col_ptr], x5\n" // Go back to first row
7761 // Now we need to advance to the next column. If we already
7762 // finished the last column, then in principle we are done, however
7763 // we can't just return here, as we need to allow the end work of the
7764 // current block to complete. The good news is that at this point it
7765 // doesn't matter what data we load for the next column, since
7766 // we will exit from the main loop below before actually storing
7767 // anything computed from that data.
7768 "cmp %w[col], w8\n" // Have we finished the last column?
7769 "bge 5f\n" // If yes, just carry on without updating the column pointer.
7770 // Not finished last column: then advance to next column.
7771 "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
7772 "5:\n"
7773
7774 // Set the LHS and RHS data pointers to the start of the columns just
7775 // computed.
7776 "mov %[lhs_ptr], %[lhs_col_ptr]\n"
7777 "mov %[rhs_ptr], %[rhs_col_ptr]\n"
7778
7779 // Load some parameters needed for the end work on current block.
7780 "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
7781 "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
7782
7783 // Determine the channel index.
7784 "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
7785 "csel w3, %w[row], %w[col], eq\n"
7786
7787 // Offset the bias pointer as needed given the current row, col.
7788 "add x5, x1, x3, lsl #2\n"
7789
7790        // If there is no bias, use no offset: keep addressing the zero
7791        // data that was passed in its place.
7792
7793 "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
7794 "csel x1, x1, x5, eq\n"
7795
7796 // Load 8 bias values.
7797 "ld1 {v14.4s}, [x1], #16\n"
7798 "ld1 {v15.4s}, [x1]\n"
7799
7800 // Now that we know what LHS and RHS data the next iteration of the
7801 // main loop will need to load, we start loading the first 32 bytes of
7802 // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
7803 // in the rest of the work on the current block.
7804 "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
7805 "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
7806 "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
7807 "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
7808
7809 // Perform the bias-addition.
7810 // Jump based on channel dimension.
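        // In scalar terms, the two cases below implement (illustrative
        // pseudo-code; `acc`, `bias`, `channel_is_col` are sketch names):
        //   acc[j][i] += channel_is_col ? bias[col + j] : bias[row + i];
        // with the 8 bias values just loaded into v14/v15.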
7811 "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
7812 "bne 6f\n"
7813 // Case where channels are rows
7814 "fadd v16.4s, v16.4s, v14.4s\n"
7815 "fadd v17.4s, v17.4s, v15.4s\n"
7816 "fadd v18.4s, v18.4s, v14.4s\n"
7817 "fadd v19.4s, v19.4s, v15.4s\n"
7818 "fadd v20.4s, v20.4s, v14.4s\n"
7819 "fadd v21.4s, v21.4s, v15.4s\n"
7820 "fadd v22.4s, v22.4s, v14.4s\n"
7821 "fadd v23.4s, v23.4s, v15.4s\n"
7822 "fadd v24.4s, v24.4s, v14.4s\n"
7823 "fadd v25.4s, v25.4s, v15.4s\n"
7824 "fadd v26.4s, v26.4s, v14.4s\n"
7825 "fadd v27.4s, v27.4s, v15.4s\n"
7826 "fadd v28.4s, v28.4s, v14.4s\n"
7827 "fadd v29.4s, v29.4s, v15.4s\n"
7828 "fadd v30.4s, v30.4s, v14.4s\n"
7829 "fadd v31.4s, v31.4s, v15.4s\n"
7830 "b 7f\n"
7831
7832 "6:\n"
7833 // Case where channels are columns
7834 "dup v8.4s, v14.s[0]\n"
7835 "dup v9.4s, v14.s[1]\n"
7836 "fadd v16.4s, v16.4s, v8.4s\n"
7837 "dup v10.4s, v14.s[2]\n"
7838 "fadd v17.4s, v17.4s, v8.4s\n"
7839 "dup v11.4s, v14.s[3]\n"
7840 "fadd v18.4s, v18.4s, v9.4s\n"
7841 "dup v12.4s, v15.s[0]\n"
7842 "fadd v19.4s, v19.4s, v9.4s\n"
7843 "dup v13.4s, v15.s[1]\n"
7844 "fadd v20.4s, v20.4s, v10.4s\n"
7845 "dup v14.4s, v15.s[2]\n"
7846 "fadd v21.4s, v21.4s, v10.4s\n"
7847 "dup v15.4s, v15.s[3]\n"
7848 "fadd v22.4s, v22.4s, v11.4s\n"
7849 "fadd v23.4s, v23.4s, v11.4s\n"
7850 "fadd v24.4s, v24.4s, v12.4s\n"
7851 "fadd v25.4s, v25.4s, v12.4s\n"
7852 "fadd v26.4s, v26.4s, v13.4s\n"
7853 "fadd v27.4s, v27.4s, v13.4s\n"
7854 "fadd v28.4s, v28.4s, v14.4s\n"
7855 "fadd v29.4s, v29.4s, v14.4s\n"
7856 "fadd v30.4s, v30.4s, v15.4s\n"
7857 "fadd v31.4s, v31.4s, v15.4s\n"
7858 "7:\n"
7859
7860 // Load the clamp_min, clamp_max bounds
7861 "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
7862 "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
7863 "dup v14.4s, w2\n" // clamp_min
7864 "dup v15.4s, w3\n" // clamp_max
7865
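        // Scalar equivalent of the fmax/fmin pairs below (sketch):
        //   acc[j][i] = std::min(std::max(acc[j][i], clamp_min), clamp_max);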
7866 // Apply the clamp_min bound
7867 "fmax v16.4s, v16.4s, v14.4s\n"
7868 "fmax v17.4s, v17.4s, v14.4s\n"
7869 "fmax v18.4s, v18.4s, v14.4s\n"
7870 "fmax v19.4s, v19.4s, v14.4s\n"
7871 "fmax v20.4s, v20.4s, v14.4s\n"
7872 "fmax v21.4s, v21.4s, v14.4s\n"
7873 "fmax v22.4s, v22.4s, v14.4s\n"
7874 "fmax v23.4s, v23.4s, v14.4s\n"
7875 "fmax v24.4s, v24.4s, v14.4s\n"
7876 "fmax v25.4s, v25.4s, v14.4s\n"
7877 "fmax v26.4s, v26.4s, v14.4s\n"
7878 "fmax v27.4s, v27.4s, v14.4s\n"
7879 "fmax v28.4s, v28.4s, v14.4s\n"
7880 "fmax v29.4s, v29.4s, v14.4s\n"
7881 "fmax v30.4s, v30.4s, v14.4s\n"
7882 "fmax v31.4s, v31.4s, v14.4s\n"
7883
7884 // Apply the clamp_max bound
7885 "fmin v16.4s, v16.4s, v15.4s\n"
7886 "fmin v17.4s, v17.4s, v15.4s\n"
7887 "fmin v18.4s, v18.4s, v15.4s\n"
7888 "fmin v19.4s, v19.4s, v15.4s\n"
7889 "fmin v20.4s, v20.4s, v15.4s\n"
7890 "fmin v21.4s, v21.4s, v15.4s\n"
7891 "fmin v22.4s, v22.4s, v15.4s\n"
7892 "fmin v23.4s, v23.4s, v15.4s\n"
7893 "fmin v24.4s, v24.4s, v15.4s\n"
7894 "fmin v25.4s, v25.4s, v15.4s\n"
7895 "fmin v26.4s, v26.4s, v15.4s\n"
7896 "fmin v27.4s, v27.4s, v15.4s\n"
7897 "fmin v28.4s, v28.4s, v15.4s\n"
7898 "fmin v29.4s, v29.4s, v15.4s\n"
7899 "fmin v30.4s, v30.4s, v15.4s\n"
7900 "fmin v31.4s, v31.4s, v15.4s\n"
7901
7902        // Compute how much of the computed 8x8 block of destination
7903        // float values fits in the destination matrix. Typically, all of
7904        // it fits, but when the destination matrix shape is not a multiple
7905        // of 8x8, there are some 8x8 blocks along the boundaries that do
7906        // not fit entirely.
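        // In C++ terms, the next few instructions compute (sketch):
        //   rows_fit = std::min(dst_rows - row, 8);  // w1
        //   cols_fit = std::min(dst_cols - col, 8);  // w2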
7907 "sub w1, %w[dst_rows], %w[row]\n"
7908 "sub w2, %w[dst_cols], %w[col]\n"
7909 "mov w3, #8\n"
7910 "cmp w1, #8\n"
7911 // Compute w1 = how many rows of the 8x8 block fit
7912 "csel w1, w1, w3, le\n"
7913 "cmp w2, #8\n"
7914 // Compute w2 = how many cols of the 8x8 block fit
7915 "csel w2, w2, w3, le\n"
7916
7917        // Test if w1 == 8 && w2 == 8, i.e. if all of the 8x8 block fits.
7918 "cmp w1, w3\n"
7919 "ccmp w2, w3, 0, eq\n"
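        // Note on the ccmp above: if the preceding cmp set "eq" (w1 == 8),
        // it compares w2 with w3; otherwise it sets the flags to the
        // immediate 0, which clears "eq". So "eq" now means
        // (w1 == 8) && (w2 == 8), and since none of the stores below modify
        // the flags, the same condition is reused by "beq 41f" further down.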
7920 // Yes, all of the 8x8 block fits, go to fast path.
7921 "beq 30f\n"
7922 // Not all of the 8x8 block fits.
7923 // Set (x3 address, x4 stride) to write to dst_tmp_buf
7924 "mov x3, %[dst_tmp_buf]\n"
7925 "mov x4, #32\n"
7926 "b 31f\n"
7927 "30:\n"
7928 // Yes, all of the 8x8 block fits.
7929 // Set (x3 address, x4 stride) to write directly to destination matrix.
7930 "mov x3, %[dst_ptr]\n"
7931 "mov x4, x11\n"
7932 "31:\n"
7933
7934        // Write our float values to the destination described by
7935        // (x3 address, x4 stride).
7936 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7937 "str q16, [x3, #0]\n"
7938 "str q17, [x3, #16]\n"
7939 "add x3, x3, x4\n"
7940 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7941 RUY_MAKE_ZERO(v16)
7942 RUY_MAKE_ZERO(v17)
7943 "str q18, [x3, #0]\n"
7944 "str q19, [x3, #16]\n"
7945 "add x3, x3, x4\n"
7946 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7947 RUY_MAKE_ZERO(v18)
7948 RUY_MAKE_ZERO(v19)
7949 "str q20, [x3, #0]\n"
7950 "str q21, [x3, #16]\n"
7951 "add x3, x3, x4\n"
7952 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7953 RUY_MAKE_ZERO(v20)
7954 RUY_MAKE_ZERO(v21)
7955 "str q22, [x3, #0]\n"
7956 "str q23, [x3, #16]\n"
7957 "add x3, x3, x4\n"
7958 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7959 RUY_MAKE_ZERO(v22)
7960 RUY_MAKE_ZERO(v23)
7961 "str q24, [x3, #0]\n"
7962 "str q25, [x3, #16]\n"
7963 "add x3, x3, x4\n"
7964 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7965 RUY_MAKE_ZERO(v24)
7966 RUY_MAKE_ZERO(v25)
7967 "str q26, [x3, #0]\n"
7968 "str q27, [x3, #16]\n"
7969 "add x3, x3, x4\n"
7970 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7971 RUY_MAKE_ZERO(v26)
7972 RUY_MAKE_ZERO(v27)
7973 "str q28, [x3, #0]\n"
7974 "str q29, [x3, #16]\n"
7975 "add x3, x3, x4\n"
7976 RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
7977 RUY_MAKE_ZERO(v28)
7978 RUY_MAKE_ZERO(v29)
7979 "str q30, [x3, #0]\n"
7980 "str q31, [x3, #16]\n"
7981 RUY_MAKE_ZERO(v30)
7982 RUY_MAKE_ZERO(v31)
7983
7984 // If all of the 8x8 block fits, we just finished writing it to the
7985 // destination, so we skip the next part.
7986 "beq 41f\n"
7987 // Not all of the 8x8 block fits in the destination matrix. We just
7988 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
7989 // it to copy into the destination matrix the part that fits.
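        // Equivalent C++ for this copy loop (sketch; each float is copied as
        // a raw 32-bit word, and dst_stride, held in x11, is in bytes):
        //   for (int j = 0; j < cols_fit; j++)
        //     for (int i = 0; i < rows_fit; i++)
        //       dst_ptr[j * (dst_stride / 4) + i] = dst_tmp_buf[j * 8 + i];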
7990 "mov x3, %[dst_tmp_buf]\n"
7991 "mov x4, %[dst_ptr]\n"
7992 "mov w6, #0\n"
7993 "50:\n"
7994 RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
7995 "mov w5, #0\n"
7996 "51:\n"
7997 "ldr w7, [x3, x5, lsl #2]\n"
7998 "str w7, [x4, x5, lsl #2]\n"
7999 "add w5, w5, #1\n"
8000 "cmp w5, w1\n"
8001 "blt 51b\n"
8002 "add w6, w6, #1\n"
8003 "add x3, x3, #32\n"
8004 "add x4, x4, x11\n"
8005 "cmp w6, w2\n"
8006 "blt 50b\n"
8007 "41:\n"
8008 "add %[dst_ptr], %[dst_ptr], #32\n"
8009 // At this point we have completely finished writing values to the
8010 // destination matrix for the current block.
8011
8012 // Reload some params --- we had used x5 -- x7 for a few other things
8013 // since the last time we had loaded them.
8014 "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
8015 "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
8016 "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
8017
8018 // Move to the next block of the destination matrix, for the next iter
8019 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
8020 // been updated earlier.
8021 // Have we reached the end row?
8022 "cmp %w[row], w7\n"
8023 "beq 20f\n" // yes, end row.
8024 // Not end row. Move to the next row.
8025 "add %w[row], %w[row], #8\n"
8026 "b 21f\n"
8027 "20:\n"
8028 // Was already at end row.
8029 "mov %w[row], w6\n" // Move back to first row.
8030 "add %w[col], %w[col], #8\n" // Move to the next column.
8031 "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
8032 "mov %[dst_ptr], %[dst_col_ptr]\n"
8033 "21:\n"
8034
8035 // Main loop exit condition: have we hit the end column?
8036 "cmp %w[col], w8\n"
8037
8038 // w1 is the number of levels of depth that remain to load
8039 // LHS and RHS data for. Corresponding to the initial ld1 instructions
8040 // above, this is currently depth - 1.
8041 "sub w1, w12, #1\n"
8042
8043 "ble 1b\n"
8044
8045 // clang-format on
8046
8047        : [lhs_col_ptr] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
8048 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
8049 [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
8050        : [params] "r"(&params), [dst_rows] "r"(params.dst_rows),
8051 [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf)
8052 : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
8053 "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
8054 "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
8055 "v26", "v27", "v28", "v29", "v30", "v31");
8056 }
8057 #undef RUY_OFFSET_BIAS
8058 #undef RUY_OFFSET_FLAGS
8059 #undef RUY_OFFSET_LHS_BASE_PTR
8060 #undef RUY_OFFSET_CLAMP_MIN
8061 #undef RUY_OFFSET_CLAMP_MAX
8062 #undef RUY_OFFSET_START_ROW
8063 #undef RUY_OFFSET_LAST_ROW
8064 #undef RUY_OFFSET_LAST_COL
8065 #undef RUY_OFFSET_LHS_STRIDE
8066 #undef RUY_OFFSET_RHS_STRIDE
8067 #undef RUY_OFFSET_DST_STRIDE
8068 #undef RUY_OFFSET_DEPTH
8069 #undef RUY_OFFSET_START_COL
8070 #undef RUY_OFFSET_RHS_BASE_PTR
8071 #undef RUY_OFFSET_DST_BASE_PTR
8072
8073 #endif // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
8074
8075 } // namespace ruy
8076