// Auto-generated file. Do not edit!
//   Template: src/qs8-gavgpool/multipass-neon.c.in
//   Generator: tools/xngen
//
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/gavgpool.h>
#include <xnnpack/math.h>


xnn_qs8_gavgpool_minmax_ukernel_7p7x__neon_c24_acc2(size_t rows,size_t channels,const int8_t * input,size_t input_stride,const int8_t * zero,int32_t * buffer,int8_t * output,const union xnn_qs8_avgpool_params params[restrict XNN_MIN_ELEMENTS (1)])18 void xnn_qs8_gavgpool_minmax_ukernel_7p7x__neon_c24_acc2(
19 size_t rows,
20 size_t channels,
21 const int8_t* input,
22 size_t input_stride,
23 const int8_t* zero,
24 int32_t* buffer,
25 int8_t* output,
26 const union xnn_qs8_avgpool_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
27 {
28 assert(rows > 7);
29 assert(channels != 0);
30
31 const int8_t* i0 = input;
32 const int8_t* i1 = (const int8_t*) ((uintptr_t) i0 + input_stride);
33 const int8_t* i2 = (const int8_t*) ((uintptr_t) i1 + input_stride);
34 const int8_t* i3 = (const int8_t*) ((uintptr_t) i2 + input_stride);
35 const int8_t* i4 = (const int8_t*) ((uintptr_t) i3 + input_stride);
36 const int8_t* i5 = (const int8_t*) ((uintptr_t) i4 + input_stride);
37 const int8_t* i6 = (const int8_t*) ((uintptr_t) i5 + input_stride);
38 const size_t input_increment = 7 * input_stride - round_up_po2(channels, 8);
39
40 const int32x4_t vbias = vld1q_dup_s32(¶ms->neon.bias);
41 int32_t* b = buffer;
42 size_t c = channels;
43 for (; c >= 24; c -= 24) {
44 const int8x8_t vi0x01234567 = vld1_s8(i0); i0 += 8;
45 const int8x8_t vi0x89ABCDEF = vld1_s8(i0); i0 += 8;
46 const int8x8_t vi0xGHIJKLMN = vld1_s8(i0); i0 += 8;
47 const int8x8_t vi1x01234567 = vld1_s8(i1); i1 += 8;
48 const int8x8_t vi1x89ABCDEF = vld1_s8(i1); i1 += 8;
49 const int8x8_t vi1xGHIJKLMN = vld1_s8(i1); i1 += 8;
50 const int8x8_t vi2x01234567 = vld1_s8(i2); i2 += 8;
51 const int8x8_t vi2x89ABCDEF = vld1_s8(i2); i2 += 8;
52 const int8x8_t vi2xGHIJKLMN = vld1_s8(i2); i2 += 8;
53 const int8x8_t vi3x01234567 = vld1_s8(i3); i3 += 8;
54 const int8x8_t vi3x89ABCDEF = vld1_s8(i3); i3 += 8;
55 const int8x8_t vi3xGHIJKLMN = vld1_s8(i3); i3 += 8;
56 const int8x8_t vi4x01234567 = vld1_s8(i4); i4 += 8;
57 const int8x8_t vi4x89ABCDEF = vld1_s8(i4); i4 += 8;
58 const int8x8_t vi4xGHIJKLMN = vld1_s8(i4); i4 += 8;
59 const int8x8_t vi5x01234567 = vld1_s8(i5); i5 += 8;
60 const int8x8_t vi5x89ABCDEF = vld1_s8(i5); i5 += 8;
61 const int8x8_t vi5xGHIJKLMN = vld1_s8(i5); i5 += 8;
62 const int8x8_t vi6x01234567 = vld1_s8(i6); i6 += 8;
63 const int8x8_t vi6x89ABCDEF = vld1_s8(i6); i6 += 8;
64 const int8x8_t vi6xGHIJKLMN = vld1_s8(i6); i6 += 8;
65
66 int16x8_t vacc0x01234567 = vaddl_s8(vi0x01234567, vi1x01234567);
67 int16x8_t vacc0x89ABCDEF = vaddl_s8(vi0x89ABCDEF, vi1x89ABCDEF);
68 int16x8_t vacc0xGHIJKLMN = vaddl_s8(vi0xGHIJKLMN, vi1xGHIJKLMN);
69 int16x8_t vacc1x01234567 = vaddl_s8(vi2x01234567, vi3x01234567);
70 int16x8_t vacc1x89ABCDEF = vaddl_s8(vi2x89ABCDEF, vi3x89ABCDEF);
71 int16x8_t vacc1xGHIJKLMN = vaddl_s8(vi2xGHIJKLMN, vi3xGHIJKLMN);
72
73 vacc0x01234567 = vaddw_s8(vacc0x01234567, vi4x01234567);
74 vacc0x89ABCDEF = vaddw_s8(vacc0x89ABCDEF, vi4x89ABCDEF);
75 vacc0xGHIJKLMN = vaddw_s8(vacc0xGHIJKLMN, vi4xGHIJKLMN);
76 vacc1x01234567 = vaddw_s8(vacc1x01234567, vi5x01234567);
77 vacc1x89ABCDEF = vaddw_s8(vacc1x89ABCDEF, vi5x89ABCDEF);
78 vacc1xGHIJKLMN = vaddw_s8(vacc1xGHIJKLMN, vi5xGHIJKLMN);
79 vacc0x01234567 = vaddw_s8(vacc0x01234567, vi6x01234567);
80 vacc0x89ABCDEF = vaddw_s8(vacc0x89ABCDEF, vi6x89ABCDEF);
81 vacc0xGHIJKLMN = vaddw_s8(vacc0xGHIJKLMN, vi6xGHIJKLMN);
82
83 // Add up all accumulators to vacc0x0123456789ABCDEFGHIJKLMN
84 vacc0x01234567 = vaddq_s16(vacc0x01234567, vacc1x01234567);
85 vacc0x89ABCDEF = vaddq_s16(vacc0x89ABCDEF, vacc1x89ABCDEF);
86 vacc0xGHIJKLMN = vaddq_s16(vacc0xGHIJKLMN, vacc1xGHIJKLMN);
87
88 const int32x4_t vacc0123 = vaddw_s16(vbias, vget_low_s16(vacc0x01234567));
89 const int32x4_t vacc4567 = vaddw_s16(vbias, vget_high_s16(vacc0x01234567));
90 const int32x4_t vacc89AB = vaddw_s16(vbias, vget_low_s16(vacc0x89ABCDEF));
91 const int32x4_t vaccCDEF = vaddw_s16(vbias, vget_high_s16(vacc0x89ABCDEF));
92 const int32x4_t vaccGHIJ = vaddw_s16(vbias, vget_low_s16(vacc0xGHIJKLMN));
93 const int32x4_t vaccKLMN = vaddw_s16(vbias, vget_high_s16(vacc0xGHIJKLMN));
94
95 vst1q_s32(b, vacc0123); b += 4;
96 vst1q_s32(b, vacc4567); b += 4;
97 vst1q_s32(b, vacc89AB); b += 4;
98 vst1q_s32(b, vaccCDEF); b += 4;
99 vst1q_s32(b, vaccGHIJ); b += 4;
100 vst1q_s32(b, vaccKLMN); b += 4;
101 }
102 if XNN_UNLIKELY(c != 0) {
103 do {
104 const int8x8_t vi0x01234567 = vld1_s8(i0); i0 += 8;
105 const int8x8_t vi1x01234567 = vld1_s8(i1); i1 += 8;
106 const int8x8_t vi2x01234567 = vld1_s8(i2); i2 += 8;
107 const int8x8_t vi3x01234567 = vld1_s8(i3); i3 += 8;
108 const int8x8_t vi4x01234567 = vld1_s8(i4); i4 += 8;
109 const int8x8_t vi5x01234567 = vld1_s8(i5); i5 += 8;
110 const int8x8_t vi6x01234567 = vld1_s8(i6); i6 += 8;
111
112 int16x8_t vacc0x01234567 = vaddl_s8(vi0x01234567, vi1x01234567);
113 int16x8_t vacc1x01234567 = vaddl_s8(vi2x01234567, vi3x01234567);
114
115 vacc0x01234567 = vaddw_s8(vacc0x01234567, vi4x01234567);
116 vacc1x01234567 = vaddw_s8(vacc1x01234567, vi5x01234567);
117 vacc0x01234567 = vaddw_s8(vacc0x01234567, vi6x01234567);
118
119 // Add up all accumulators to vacc0x01234567
120 vacc0x01234567 = vaddq_s16(vacc0x01234567, vacc1x01234567);
121
122 const int32x4_t vacc0123 = vaddw_s16(vbias, vget_low_s16(vacc0x01234567));
123 const int32x4_t vacc4567 = vaddw_s16(vbias, vget_high_s16(vacc0x01234567));
124
125 vst1q_s32(b, vacc0123); b += 4;
126 vst1q_s32(b, vacc4567); b += 4;
127
128 c = doz(c, 8);
129 } while (c != 0);
130 }
131
132 for (rows -= 7; rows > 7; rows -= 7) {
133 i0 = (const int8_t*) ((uintptr_t) i0 + input_increment);
134 i1 = (const int8_t*) ((uintptr_t) i1 + input_increment);
135 i2 = (const int8_t*) ((uintptr_t) i2 + input_increment);
136 i3 = (const int8_t*) ((uintptr_t) i3 + input_increment);
137 i4 = (const int8_t*) ((uintptr_t) i4 + input_increment);
138 i5 = (const int8_t*) ((uintptr_t) i5 + input_increment);
139 i6 = (const int8_t*) ((uintptr_t) i6 + input_increment);
140
141 int32_t* b = buffer;
142 size_t c = channels;
143 for (; c >= 24; c -= 24) {
144 const int8x8_t vi0x01234567 = vld1_s8(i0); i0 += 8;
145 const int8x8_t vi0x89ABCDEF = vld1_s8(i0); i0 += 8;
146 const int8x8_t vi0xGHIJKLMN = vld1_s8(i0); i0 += 8;
147 const int8x8_t vi1x01234567 = vld1_s8(i1); i1 += 8;
148 const int8x8_t vi1x89ABCDEF = vld1_s8(i1); i1 += 8;
149 const int8x8_t vi1xGHIJKLMN = vld1_s8(i1); i1 += 8;
150 const int8x8_t vi2x01234567 = vld1_s8(i2); i2 += 8;
151 const int8x8_t vi2x89ABCDEF = vld1_s8(i2); i2 += 8;
152 const int8x8_t vi2xGHIJKLMN = vld1_s8(i2); i2 += 8;
153 const int8x8_t vi3x01234567 = vld1_s8(i3); i3 += 8;
154 const int8x8_t vi3x89ABCDEF = vld1_s8(i3); i3 += 8;
155 const int8x8_t vi3xGHIJKLMN = vld1_s8(i3); i3 += 8;
156 const int8x8_t vi4x01234567 = vld1_s8(i4); i4 += 8;
157 const int8x8_t vi4x89ABCDEF = vld1_s8(i4); i4 += 8;
158 const int8x8_t vi4xGHIJKLMN = vld1_s8(i4); i4 += 8;
159 const int8x8_t vi5x01234567 = vld1_s8(i5); i5 += 8;
160 const int8x8_t vi5x89ABCDEF = vld1_s8(i5); i5 += 8;
161 const int8x8_t vi5xGHIJKLMN = vld1_s8(i5); i5 += 8;
162 const int8x8_t vi6x01234567 = vld1_s8(i6); i6 += 8;
163 const int8x8_t vi6x89ABCDEF = vld1_s8(i6); i6 += 8;
164 const int8x8_t vi6xGHIJKLMN = vld1_s8(i6); i6 += 8;
165
166 int16x8_t vacc0x01234567 = vaddl_s8(vi0x01234567, vi1x01234567);
167 int16x8_t vacc0x89ABCDEF = vaddl_s8(vi0x89ABCDEF, vi1x89ABCDEF);
168 int16x8_t vacc0xGHIJKLMN = vaddl_s8(vi0xGHIJKLMN, vi1xGHIJKLMN);
169 int16x8_t vacc1x01234567 = vaddl_s8(vi2x01234567, vi3x01234567);
170 int16x8_t vacc1x89ABCDEF = vaddl_s8(vi2x89ABCDEF, vi3x89ABCDEF);
171 int16x8_t vacc1xGHIJKLMN = vaddl_s8(vi2xGHIJKLMN, vi3xGHIJKLMN);
172
173 vacc0x01234567 = vaddw_s8(vacc0x01234567, vi4x01234567);
174 vacc0x89ABCDEF = vaddw_s8(vacc0x89ABCDEF, vi4x89ABCDEF);
175 vacc0xGHIJKLMN = vaddw_s8(vacc0xGHIJKLMN, vi4xGHIJKLMN);
176 vacc1x01234567 = vaddw_s8(vacc1x01234567, vi5x01234567);
177 vacc1x89ABCDEF = vaddw_s8(vacc1x89ABCDEF, vi5x89ABCDEF);
178 vacc1xGHIJKLMN = vaddw_s8(vacc1xGHIJKLMN, vi5xGHIJKLMN);
179 vacc0x01234567 = vaddw_s8(vacc0x01234567, vi6x01234567);
180 vacc0x89ABCDEF = vaddw_s8(vacc0x89ABCDEF, vi6x89ABCDEF);
181 vacc0xGHIJKLMN = vaddw_s8(vacc0xGHIJKLMN, vi6xGHIJKLMN);
182
183 // Add up all accumulators to vacc0x0123456789ABCDEFGHIJKLMN
184 vacc0x01234567 = vaddq_s16(vacc0x01234567, vacc1x01234567);
185 vacc0x89ABCDEF = vaddq_s16(vacc0x89ABCDEF, vacc1x89ABCDEF);
186 vacc0xGHIJKLMN = vaddq_s16(vacc0xGHIJKLMN, vacc1xGHIJKLMN);
187
188 int32x4_t vacc0123 = vld1q_s32(b);
189 int32x4_t vacc4567 = vld1q_s32(b + 4);
190 int32x4_t vacc89AB = vld1q_s32(b + 8);
191 int32x4_t vaccCDEF = vld1q_s32(b + 12);
192 int32x4_t vaccGHIJ = vld1q_s32(b + 16);
193 int32x4_t vaccKLMN = vld1q_s32(b + 20);
194
195 vacc0123 = vaddw_s16(vacc0123, vget_low_s16(vacc0x01234567));
196 vacc4567 = vaddw_s16(vacc4567, vget_high_s16(vacc0x01234567));
197 vacc89AB = vaddw_s16(vacc89AB, vget_low_s16(vacc0x89ABCDEF));
198 vaccCDEF = vaddw_s16(vaccCDEF, vget_high_s16(vacc0x89ABCDEF));
199 vaccGHIJ = vaddw_s16(vaccGHIJ, vget_low_s16(vacc0xGHIJKLMN));
200 vaccKLMN = vaddw_s16(vaccKLMN, vget_high_s16(vacc0xGHIJKLMN));
201
202 vst1q_s32(b, vacc0123); b += 4;
203 vst1q_s32(b, vacc4567); b += 4;
204 vst1q_s32(b, vacc89AB); b += 4;
205 vst1q_s32(b, vaccCDEF); b += 4;
206 vst1q_s32(b, vaccGHIJ); b += 4;
207 vst1q_s32(b, vaccKLMN); b += 4;
208 }
209 if XNN_UNLIKELY(c != 0) {
210 do {
211 const int8x8_t vi0x01234567 = vld1_s8(i0); i0 += 8;
212 const int8x8_t vi1x01234567 = vld1_s8(i1); i1 += 8;
213 const int8x8_t vi2x01234567 = vld1_s8(i2); i2 += 8;
214 const int8x8_t vi3x01234567 = vld1_s8(i3); i3 += 8;
215 const int8x8_t vi4x01234567 = vld1_s8(i4); i4 += 8;
216 const int8x8_t vi5x01234567 = vld1_s8(i5); i5 += 8;
217 const int8x8_t vi6x01234567 = vld1_s8(i6); i6 += 8;
218
219 int16x8_t vacc0x01234567 = vaddl_s8(vi0x01234567, vi1x01234567);
220 int16x8_t vacc1x01234567 = vaddl_s8(vi2x01234567, vi3x01234567);
221
222 vacc0x01234567 = vaddw_s8(vacc0x01234567, vi4x01234567);
223 vacc1x01234567 = vaddw_s8(vacc1x01234567, vi5x01234567);
224 vacc0x01234567 = vaddw_s8(vacc0x01234567, vi6x01234567);
225
226 // Add up all accumulators to vacc0x01234567
227 vacc0x01234567 = vaddq_s16(vacc0x01234567, vacc1x01234567);
228
229 int32x4_t vacc0123 = vld1q_s32(b);
230 int32x4_t vacc4567 = vld1q_s32(b + 4);
231
232 vacc0123 = vaddw_s16(vacc0123, vget_low_s16(vacc0x01234567));
233 vacc4567 = vaddw_s16(vacc4567, vget_high_s16(vacc0x01234567));
234
235 vst1q_s32(b, vacc0123); b += 4;
236 vst1q_s32(b, vacc4567); b += 4;
237
238 c = doz(c, 8);
239 } while (c != 0);
240 }
241 }
242
243 i0 = (const int8_t*) ((uintptr_t) i0 + input_increment);
244 i1 = (const int8_t*) ((uintptr_t) i1 + input_increment);
245 if XNN_UNPREDICTABLE(rows < 2) {
246 i1 = zero;
247 }
248 i2 = (const int8_t*) ((uintptr_t) i2 + input_increment);
249 if XNN_UNPREDICTABLE(rows <= 2) {
250 i2 = zero;
251 }
252 i3 = (const int8_t*) ((uintptr_t) i3 + input_increment);
253 if XNN_UNPREDICTABLE(rows < 4) {
254 i3 = zero;
255 }
256 i4 = (const int8_t*) ((uintptr_t) i4 + input_increment);
257 if XNN_UNPREDICTABLE(rows <= 4) {
258 i4 = zero;
259 }
260 i5 = (const int8_t*) ((uintptr_t) i5 + input_increment);
261 if XNN_UNPREDICTABLE(rows < 6) {
262 i5 = zero;
263 }
264 i6 = (const int8_t*) ((uintptr_t) i6 + input_increment);
265 if XNN_UNPREDICTABLE(rows <= 6) {
266 i6 = zero;
267 }
268
269 #if XNN_ARCH_ARM64
270 const int32x4_t vmultiplier = vld1q_dup_s32(¶ms->neon.multiplier);
271 #else
272 const int32x2_t vmultiplier = vld1_dup_s32(¶ms->neon.multiplier);
273 #endif
274 const int64x2_t vleft_shift = vld1q_dup_s64(¶ms->neon.left_shift);
275 const int16x8_t voutput_zero_point = vld1q_dup_s16(¶ms->neon.output_zero_point);
276 const int8x16_t voutput_min = vld1q_dup_s8(¶ms->neon.output_min);
277 const int8x16_t voutput_max = vld1q_dup_s8(¶ms->neon.output_max);
278 while (channels >= 24) {
279 const int8x8_t vi0x01234567 = vld1_s8(i0); i0 += 8;
280 const int8x8_t vi0x89ABCDEF = vld1_s8(i0); i0 += 8;
281 const int8x8_t vi0xGHIJKLMN = vld1_s8(i0); i0 += 8;
282 const int8x8_t vi1x01234567 = vld1_s8(i1); i1 += 8;
283 const int8x8_t vi1x89ABCDEF = vld1_s8(i1); i1 += 8;
284 const int8x8_t vi1xGHIJKLMN = vld1_s8(i1); i1 += 8;
285 const int8x8_t vi2x01234567 = vld1_s8(i2); i2 += 8;
286 const int8x8_t vi2x89ABCDEF = vld1_s8(i2); i2 += 8;
287 const int8x8_t vi2xGHIJKLMN = vld1_s8(i2); i2 += 8;
288 const int8x8_t vi3x01234567 = vld1_s8(i3); i3 += 8;
289 const int8x8_t vi3x89ABCDEF = vld1_s8(i3); i3 += 8;
290 const int8x8_t vi3xGHIJKLMN = vld1_s8(i3); i3 += 8;
291 const int8x8_t vi4x01234567 = vld1_s8(i4); i4 += 8;
292 const int8x8_t vi4x89ABCDEF = vld1_s8(i4); i4 += 8;
293 const int8x8_t vi4xGHIJKLMN = vld1_s8(i4); i4 += 8;
294 const int8x8_t vi5x01234567 = vld1_s8(i5); i5 += 8;
295 const int8x8_t vi5x89ABCDEF = vld1_s8(i5); i5 += 8;
296 const int8x8_t vi5xGHIJKLMN = vld1_s8(i5); i5 += 8;
297 const int8x8_t vi6x01234567 = vld1_s8(i6); i6 += 8;
298 const int8x8_t vi6x89ABCDEF = vld1_s8(i6); i6 += 8;
299 const int8x8_t vi6xGHIJKLMN = vld1_s8(i6); i6 += 8;
300
301 int16x8_t vacc0x01234567 = vaddl_s8(vi0x01234567, vi1x01234567);
302 int16x8_t vacc0x89ABCDEF = vaddl_s8(vi0x89ABCDEF, vi1x89ABCDEF);
303 int16x8_t vacc0xGHIJKLMN = vaddl_s8(vi0xGHIJKLMN, vi1xGHIJKLMN);
304 int16x8_t vacc1x01234567 = vaddl_s8(vi2x01234567, vi3x01234567);
305 int16x8_t vacc1x89ABCDEF = vaddl_s8(vi2x89ABCDEF, vi3x89ABCDEF);
306 int16x8_t vacc1xGHIJKLMN = vaddl_s8(vi2xGHIJKLMN, vi3xGHIJKLMN);
307
308 vacc0x01234567 = vaddw_s8(vacc0x01234567, vi4x01234567);
309 vacc0x89ABCDEF = vaddw_s8(vacc0x89ABCDEF, vi4x89ABCDEF);
310 vacc0xGHIJKLMN = vaddw_s8(vacc0xGHIJKLMN, vi4xGHIJKLMN);
311 vacc1x01234567 = vaddw_s8(vacc1x01234567, vi5x01234567);
312 vacc1x89ABCDEF = vaddw_s8(vacc1x89ABCDEF, vi5x89ABCDEF);
313 vacc1xGHIJKLMN = vaddw_s8(vacc1xGHIJKLMN, vi5xGHIJKLMN);
314 vacc0x01234567 = vaddw_s8(vacc0x01234567, vi6x01234567);
315 vacc0x89ABCDEF = vaddw_s8(vacc0x89ABCDEF, vi6x89ABCDEF);
316 vacc0xGHIJKLMN = vaddw_s8(vacc0xGHIJKLMN, vi6xGHIJKLMN);
317
318 // Add up all accumulators to vacc0x0123456789ABCDEFGHIJKLMN
319 vacc0x01234567 = vaddq_s16(vacc0x01234567, vacc1x01234567);
320 vacc0x89ABCDEF = vaddq_s16(vacc0x89ABCDEF, vacc1x89ABCDEF);
321 vacc0xGHIJKLMN = vaddq_s16(vacc0xGHIJKLMN, vacc1xGHIJKLMN);
322
323 int32x4_t vacc0123 = vld1q_s32(buffer); buffer += 4;
324 int32x4_t vacc4567 = vld1q_s32(buffer); buffer += 4;
325 int32x4_t vacc89AB = vld1q_s32(buffer); buffer += 4;
326 int32x4_t vaccCDEF = vld1q_s32(buffer); buffer += 4;
327 int32x4_t vaccGHIJ = vld1q_s32(buffer); buffer += 4;
328 int32x4_t vaccKLMN = vld1q_s32(buffer); buffer += 4;
329
330 vacc0123 = vaddw_s16(vacc0123, vget_low_s16(vacc0x01234567));
331 vacc4567 = vaddw_s16(vacc4567, vget_high_s16(vacc0x01234567));
332 vacc89AB = vaddw_s16(vacc89AB, vget_low_s16(vacc0x89ABCDEF));
333 vaccCDEF = vaddw_s16(vaccCDEF, vget_high_s16(vacc0x89ABCDEF));
334 vaccGHIJ = vaddw_s16(vaccGHIJ, vget_low_s16(vacc0xGHIJKLMN));
335 vaccKLMN = vaddw_s16(vaccKLMN, vget_high_s16(vacc0xGHIJKLMN));
336
337 const int32x4_t vsgnacc0123 = vreinterpretq_s32_u32(vcltq_s32(vacc0123, vmovq_n_s32(0)));
338 const int32x4_t vsgnacc4567 = vreinterpretq_s32_u32(vcltq_s32(vacc4567, vmovq_n_s32(0)));
339 const int32x4_t vsgnacc89AB = vreinterpretq_s32_u32(vcltq_s32(vacc89AB, vmovq_n_s32(0)));
340 const int32x4_t vsgnaccCDEF = vreinterpretq_s32_u32(vcltq_s32(vaccCDEF, vmovq_n_s32(0)));
341 const int32x4_t vsgnaccGHIJ = vreinterpretq_s32_u32(vcltq_s32(vaccGHIJ, vmovq_n_s32(0)));
342 const int32x4_t vsgnaccKLMN = vreinterpretq_s32_u32(vcltq_s32(vaccKLMN, vmovq_n_s32(0)));
343
344 #if XNN_ARCH_ARM64
345 const int64x2_t vprod01 = vmull_s32(vget_low_s32(vacc0123), vget_low_s32(vmultiplier));
346 const int64x2_t vprod23 = vmull_high_s32(vacc0123, vmultiplier);
347 const int64x2_t vprod45 = vmull_s32(vget_low_s32(vacc4567), vget_low_s32(vmultiplier));
348 const int64x2_t vprod67 = vmull_high_s32(vacc4567, vmultiplier);
349 const int64x2_t vprod89 = vmull_s32(vget_low_s32(vacc89AB), vget_low_s32(vmultiplier));
350 const int64x2_t vprodAB = vmull_high_s32(vacc89AB, vmultiplier);
351 const int64x2_t vprodCD = vmull_s32(vget_low_s32(vaccCDEF), vget_low_s32(vmultiplier));
352 const int64x2_t vprodEF = vmull_high_s32(vaccCDEF, vmultiplier);
353 const int64x2_t vprodGH = vmull_s32(vget_low_s32(vaccGHIJ), vget_low_s32(vmultiplier));
354 const int64x2_t vprodIJ = vmull_high_s32(vaccGHIJ, vmultiplier);
355 const int64x2_t vprodKL = vmull_s32(vget_low_s32(vaccKLMN), vget_low_s32(vmultiplier));
356 const int64x2_t vprodMN = vmull_high_s32(vaccKLMN, vmultiplier);
357
358 const int64x2_t vadjprod01 = vaddw_s32(vprod01, vget_low_s32(vsgnacc0123));
359 const int64x2_t vadjprod23 = vaddw_high_s32(vprod23, vsgnacc0123);
360 const int64x2_t vadjprod45 = vaddw_s32(vprod45, vget_low_s32(vsgnacc4567));
361 const int64x2_t vadjprod67 = vaddw_high_s32(vprod67, vsgnacc4567);
362 const int64x2_t vadjprod89 = vaddw_s32(vprod89, vget_low_s32(vsgnacc89AB));
363 const int64x2_t vadjprodAB = vaddw_high_s32(vprodAB, vsgnacc89AB);
364 const int64x2_t vadjprodCD = vaddw_s32(vprodCD, vget_low_s32(vsgnaccCDEF));
365 const int64x2_t vadjprodEF = vaddw_high_s32(vprodEF, vsgnaccCDEF);
366 const int64x2_t vadjprodGH = vaddw_s32(vprodGH, vget_low_s32(vsgnaccGHIJ));
367 const int64x2_t vadjprodIJ = vaddw_high_s32(vprodIJ, vsgnaccGHIJ);
368 const int64x2_t vadjprodKL = vaddw_s32(vprodKL, vget_low_s32(vsgnaccKLMN));
369 const int64x2_t vadjprodMN = vaddw_high_s32(vprodMN, vsgnaccKLMN);
370 #else
371 const int64x2_t vprod01 = vmull_s32(vget_low_s32(vacc0123), vmultiplier);
372 const int64x2_t vprod23 = vmull_s32(vget_high_s32(vacc0123), vmultiplier);
373 const int64x2_t vprod45 = vmull_s32(vget_low_s32(vacc4567), vmultiplier);
374 const int64x2_t vprod67 = vmull_s32(vget_high_s32(vacc4567), vmultiplier);
375 const int64x2_t vprod89 = vmull_s32(vget_low_s32(vacc89AB), vmultiplier);
376 const int64x2_t vprodAB = vmull_s32(vget_high_s32(vacc89AB), vmultiplier);
377 const int64x2_t vprodCD = vmull_s32(vget_low_s32(vaccCDEF), vmultiplier);
378 const int64x2_t vprodEF = vmull_s32(vget_high_s32(vaccCDEF), vmultiplier);
379 const int64x2_t vprodGH = vmull_s32(vget_low_s32(vaccGHIJ), vmultiplier);
380 const int64x2_t vprodIJ = vmull_s32(vget_high_s32(vaccGHIJ), vmultiplier);
381 const int64x2_t vprodKL = vmull_s32(vget_low_s32(vaccKLMN), vmultiplier);
382 const int64x2_t vprodMN = vmull_s32(vget_high_s32(vaccKLMN), vmultiplier);
383
384 const int64x2_t vadjprod01 = vaddw_s32(vprod01, vget_low_s32(vsgnacc0123));
385 const int64x2_t vadjprod23 = vaddw_s32(vprod23, vget_high_s32(vsgnacc0123));
386 const int64x2_t vadjprod45 = vaddw_s32(vprod45, vget_low_s32(vsgnacc4567));
387 const int64x2_t vadjprod67 = vaddw_s32(vprod67, vget_high_s32(vsgnacc4567));
388 const int64x2_t vadjprod89 = vaddw_s32(vprod89, vget_low_s32(vsgnacc89AB));
389 const int64x2_t vadjprodAB = vaddw_s32(vprodAB, vget_high_s32(vsgnacc89AB));
390 const int64x2_t vadjprodCD = vaddw_s32(vprodCD, vget_low_s32(vsgnaccCDEF));
391 const int64x2_t vadjprodEF = vaddw_s32(vprodEF, vget_high_s32(vsgnaccCDEF));
392 const int64x2_t vadjprodGH = vaddw_s32(vprodGH, vget_low_s32(vsgnaccGHIJ));
393 const int64x2_t vadjprodIJ = vaddw_s32(vprodIJ, vget_high_s32(vsgnaccGHIJ));
394 const int64x2_t vadjprodKL = vaddw_s32(vprodKL, vget_low_s32(vsgnaccKLMN));
395 const int64x2_t vadjprodMN = vaddw_s32(vprodMN, vget_high_s32(vsgnaccKLMN));
396 #endif
397
398 const int64x2_t vacc01 = vrshlq_s64(vadjprod01, vleft_shift);
399 const int64x2_t vacc23 = vrshlq_s64(vadjprod23, vleft_shift);
400 const int64x2_t vacc45 = vrshlq_s64(vadjprod45, vleft_shift);
401 const int64x2_t vacc67 = vrshlq_s64(vadjprod67, vleft_shift);
402 const int64x2_t vacc89 = vrshlq_s64(vadjprod89, vleft_shift);
403 const int64x2_t vaccAB = vrshlq_s64(vadjprodAB, vleft_shift);
404 const int64x2_t vaccCD = vrshlq_s64(vadjprodCD, vleft_shift);
405 const int64x2_t vaccEF = vrshlq_s64(vadjprodEF, vleft_shift);
406 const int64x2_t vaccGH = vrshlq_s64(vadjprodGH, vleft_shift);
407 const int64x2_t vaccIJ = vrshlq_s64(vadjprodIJ, vleft_shift);
408 const int64x2_t vaccKL = vrshlq_s64(vadjprodKL, vleft_shift);
409 const int64x2_t vaccMN = vrshlq_s64(vadjprodMN, vleft_shift);
410
411 #if XNN_ARCH_ARM64
412 vacc0123 = vuzp1q_s32(vreinterpretq_s32_s64(vacc01), vreinterpretq_s32_s64(vacc23));
413 vacc4567 = vuzp1q_s32(vreinterpretq_s32_s64(vacc45), vreinterpretq_s32_s64(vacc67));
414 vacc89AB = vuzp1q_s32(vreinterpretq_s32_s64(vacc89), vreinterpretq_s32_s64(vaccAB));
415 vaccCDEF = vuzp1q_s32(vreinterpretq_s32_s64(vaccCD), vreinterpretq_s32_s64(vaccEF));
416 vaccGHIJ = vuzp1q_s32(vreinterpretq_s32_s64(vaccGH), vreinterpretq_s32_s64(vaccIJ));
417 vaccKLMN = vuzp1q_s32(vreinterpretq_s32_s64(vaccKL), vreinterpretq_s32_s64(vaccMN));
418
419 const int16x8_t vacc01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0123), vacc4567), voutput_zero_point);
420 const int16x8_t vacc89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc89AB), vaccCDEF), voutput_zero_point);
421 const int16x8_t vaccGHIJKLMN = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vaccGHIJ), vaccKLMN), voutput_zero_point);
422
423 int8x16_t vout0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc01234567), vacc89ABCDEF);
424 int8x8_t voutGHIJKLMN = vqmovn_s16(vaccGHIJKLMN);
425 #else
426 vacc0123 = vcombine_s32(vmovn_s64(vacc01), vmovn_s64(vacc23));
427 vacc4567 = vcombine_s32(vmovn_s64(vacc45), vmovn_s64(vacc67));
428 vacc89AB = vcombine_s32(vmovn_s64(vacc89), vmovn_s64(vaccAB));
429 vaccCDEF = vcombine_s32(vmovn_s64(vaccCD), vmovn_s64(vaccEF));
430 vaccGHIJ = vcombine_s32(vmovn_s64(vaccGH), vmovn_s64(vaccIJ));
431 vaccKLMN = vcombine_s32(vmovn_s64(vaccKL), vmovn_s64(vaccMN));
432
433 const int16x8_t vacc01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0123), vqmovn_s32(vacc4567)), voutput_zero_point);
434 const int16x8_t vacc89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc89AB), vqmovn_s32(vaccCDEF)), voutput_zero_point);
435 const int16x8_t vaccGHIJKLMN = vqaddq_s16(vcombine_s16(vqmovn_s32(vaccGHIJ), vqmovn_s32(vaccKLMN)), voutput_zero_point);
436
437 int8x16_t vout0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc01234567), vqmovn_s16(vacc89ABCDEF));
438 int8x8_t voutGHIJKLMN = vqmovn_s16(vaccGHIJKLMN);
439 #endif
440
441 vout0123456789ABCDEF = vmaxq_s8(vout0123456789ABCDEF, voutput_min);
442 voutGHIJKLMN = vmax_s8(voutGHIJKLMN, vget_low_s8(voutput_min));
443
444 vout0123456789ABCDEF = vminq_s8(vout0123456789ABCDEF, voutput_max);
445 voutGHIJKLMN = vmin_s8(voutGHIJKLMN, vget_low_s8(voutput_max));
446
447 vst1q_s8(output, vout0123456789ABCDEF); output += 16;
448 vst1_s8(output, voutGHIJKLMN); output += 8;
449
450 channels -= 24;
451 }
452 if XNN_UNLIKELY(channels != 0) {
453 do {
454 const int8x8_t vi0x01234567 = vld1_s8(i0); i0 += 8;
455 const int8x8_t vi1x01234567 = vld1_s8(i1); i1 += 8;
456 const int8x8_t vi2x01234567 = vld1_s8(i2); i2 += 8;
457 const int8x8_t vi3x01234567 = vld1_s8(i3); i3 += 8;
458 const int8x8_t vi4x01234567 = vld1_s8(i4); i4 += 8;
459 const int8x8_t vi5x01234567 = vld1_s8(i5); i5 += 8;
460 const int8x8_t vi6x01234567 = vld1_s8(i6); i6 += 8;
461
462 int16x8_t vacc0x01234567 = vaddl_s8(vi0x01234567, vi1x01234567);
463 int16x8_t vacc1x01234567 = vaddl_s8(vi2x01234567, vi3x01234567);
464
465 vacc0x01234567 = vaddw_s8(vacc0x01234567, vi4x01234567);
466 vacc1x01234567 = vaddw_s8(vacc1x01234567, vi5x01234567);
467 vacc0x01234567 = vaddw_s8(vacc0x01234567, vi6x01234567);
468
469 // Add up all accumulators to vacc0x01234567
470 vacc0x01234567 = vaddq_s16(vacc0x01234567, vacc1x01234567);
471
472 int32x4_t vacc0123 = vld1q_s32(buffer); buffer += 4;
473 int32x4_t vacc4567 = vld1q_s32(buffer); buffer += 4;
474
475 vacc0123 = vaddw_s16(vacc0123, vget_low_s16(vacc0x01234567));
476 vacc4567 = vaddw_s16(vacc4567, vget_high_s16(vacc0x01234567));
477
478 const int32x4_t vsgnacc0123 = vreinterpretq_s32_u32(vcltq_s32(vacc0123, vmovq_n_s32(0)));
479 const int32x4_t vsgnacc4567 = vreinterpretq_s32_u32(vcltq_s32(vacc4567, vmovq_n_s32(0)));
480
481 #if XNN_ARCH_ARM64
482 const int64x2_t vprod01 = vmull_s32(vget_low_s32(vacc0123), vget_low_s32(vmultiplier));
483 const int64x2_t vprod23 = vmull_high_s32(vacc0123, vmultiplier);
484 const int64x2_t vprod45 = vmull_s32(vget_low_s32(vacc4567), vget_low_s32(vmultiplier));
485 const int64x2_t vprod67 = vmull_high_s32(vacc4567, vmultiplier);
486
487 const int64x2_t vadjprod01 = vaddw_s32(vprod01, vget_low_s32(vsgnacc0123));
488 const int64x2_t vadjprod23 = vaddw_high_s32(vprod23, vsgnacc0123);
489 const int64x2_t vadjprod45 = vaddw_s32(vprod45, vget_low_s32(vsgnacc4567));
490 const int64x2_t vadjprod67 = vaddw_high_s32(vprod67, vsgnacc4567);
491 #else
492 const int64x2_t vprod01 = vmull_s32(vget_low_s32(vacc0123), vmultiplier);
493 const int64x2_t vprod23 = vmull_s32(vget_high_s32(vacc0123), vmultiplier);
494 const int64x2_t vprod45 = vmull_s32(vget_low_s32(vacc4567), vmultiplier);
495 const int64x2_t vprod67 = vmull_s32(vget_high_s32(vacc4567), vmultiplier);
496
497 const int64x2_t vadjprod01 = vaddw_s32(vprod01, vget_low_s32(vsgnacc0123));
498 const int64x2_t vadjprod23 = vaddw_s32(vprod23, vget_high_s32(vsgnacc0123));
499 const int64x2_t vadjprod45 = vaddw_s32(vprod45, vget_low_s32(vsgnacc4567));
500 const int64x2_t vadjprod67 = vaddw_s32(vprod67, vget_high_s32(vsgnacc4567));
501 #endif
502
503 const int64x2_t vacc01 = vrshlq_s64(vadjprod01, vleft_shift);
504 const int64x2_t vacc23 = vrshlq_s64(vadjprod23, vleft_shift);
505 const int64x2_t vacc45 = vrshlq_s64(vadjprod45, vleft_shift);
506 const int64x2_t vacc67 = vrshlq_s64(vadjprod67, vleft_shift);
507
508 #if XNN_ARCH_ARM64
509 vacc0123 = vuzp1q_s32(vreinterpretq_s32_s64(vacc01), vreinterpretq_s32_s64(vacc23));
510 vacc4567 = vuzp1q_s32(vreinterpretq_s32_s64(vacc45), vreinterpretq_s32_s64(vacc67));
511
512 const int16x8_t vacc01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0123), vacc4567), voutput_zero_point);
513
514 int8x8_t vout01234567 = vqmovn_s16(vacc01234567);
515 #else
516 vacc0123 = vcombine_s32(vmovn_s64(vacc01), vmovn_s64(vacc23));
517 vacc4567 = vcombine_s32(vmovn_s64(vacc45), vmovn_s64(vacc67));
518
519 const int16x8_t vacc01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0123), vqmovn_s32(vacc4567)), voutput_zero_point);
520
521 int8x8_t vout01234567 = vqmovn_s16(vacc01234567);
522 #endif
523
524 vout01234567 = vmax_s8(vout01234567, vget_low_s8(voutput_min));
525 vout01234567 = vmin_s8(vout01234567, vget_low_s8(voutput_max));
526
527 if XNN_LIKELY(channels >= 8) {
528 vst1_s8(output, vout01234567); output += 8;
529 channels -= 8;
530 } else {
531 if (channels & 4) {
532 vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_s8(vout01234567), 0); output += 4;
533 vout01234567 = vext_s8(vout01234567, vout01234567, 4);
534 }
535 if (channels & 2) {
536 vst1_lane_u16(__builtin_assume_aligned(output, 1), vreinterpret_u16_s8(vout01234567), 0); output += 2;
537 vout01234567 = vext_s8(vout01234567, vout01234567, 2);
538 }
539 if (channels & 1) {
540 vst1_lane_s8(output, vout01234567, 0); output += 1;
541 }
542 channels = 0;
543 }
544 } while (channels != 0);
545 }
546 }
547