// Auto-generated file. Do not edit!
//   Template: src/qs8-gavgpool/multipass-sse.c.in
//   Generator: tools/xngen
//
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stdint.h>
#include <string.h>

#include <smmintrin.h>

#include <xnnpack/gavgpool.h>
#include <xnnpack/math.h>


xnn_qs8_gavgpool_minmax_ukernel_7p7x__sse41_c8_acc2(size_t rows,size_t channels,const int8_t * input,size_t input_stride,const int8_t * zero,int32_t * buffer,int8_t * output,const union xnn_qs8_avgpool_params params[restrict XNN_MIN_ELEMENTS (1)])18 void xnn_qs8_gavgpool_minmax_ukernel_7p7x__sse41_c8_acc2(
19     size_t rows,
20     size_t channels,
21     const int8_t* input,
22     size_t input_stride,
23     const int8_t* zero,
24     int32_t* buffer,
25     int8_t* output,
26     const union xnn_qs8_avgpool_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
27 {
28   assert(rows > 7);
29   assert(channels != 0);
30 
31   const int8_t* i0 = input;
32   const int8_t* i1 = (const int8_t*) ((uintptr_t) i0 + input_stride);
33   const int8_t* i2 = (const int8_t*) ((uintptr_t) i1 + input_stride);
34   const int8_t* i3 = (const int8_t*) ((uintptr_t) i2 + input_stride);
35   const int8_t* i4 = (const int8_t*) ((uintptr_t) i3 + input_stride);
36   const int8_t* i5 = (const int8_t*) ((uintptr_t) i4 + input_stride);
37   const int8_t* i6 = (const int8_t*) ((uintptr_t) i5 + input_stride);
38   const size_t input_increment = 7 * input_stride - round_up_po2(channels, 8);
39 
40   const __m128i vbias = _mm_load_si128((const __m128i*) params->sse2.bias);
41   int32_t* b = buffer;
42   size_t c = channels;
43   for (; c != 0; c = doz(c, 8)) {
44     const __m128i vxi0x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i0));
45     i0 += 8;
46     const __m128i vxi1x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i1));
47     i1 += 8;
48     const __m128i vxi2x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i2));
49     i2 += 8;
50     const __m128i vxi3x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i3));
51     i3 += 8;
52     const __m128i vxi4x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i4));
53     i4 += 8;
54     const __m128i vxi5x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i5));
55     i5 += 8;
56     const __m128i vxi6x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i6));
57     i6 += 8;
58 
59 
60     __m128i vacc0x01234567 = _mm_add_epi16(vxi0x01234567, vxi1x01234567);
61     __m128i vacc1x01234567 = _mm_add_epi16(vxi2x01234567, vxi3x01234567);
62 
63     vacc0x01234567 = _mm_add_epi16(vacc0x01234567, vxi4x01234567);
64     vacc1x01234567 = _mm_add_epi16(vacc1x01234567, vxi5x01234567);
65     vacc0x01234567 = _mm_add_epi16(vacc0x01234567, vxi6x01234567);
66 
67     // Add up all accumulators to vacc0x01234567
68     vacc0x01234567 = _mm_add_epi16(vacc0x01234567, vacc1x01234567);
69 
70     const __m128i vacc0123 = _mm_add_epi32(vbias, _mm_cvtepi16_epi32(vacc0x01234567));
71     const __m128i vacc4567 = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vacc0x01234567, _mm_cmpgt_epi16(_mm_setzero_si128(), vacc0x01234567)));
72 
73     _mm_store_si128((__m128i*) b, vacc0123);
74     _mm_store_si128((__m128i*) (b + 4), vacc4567);
75     b += 8;
76   }
77 
78   for (rows -= 7; rows > 7; rows -= 7) {
79     i0 = (const int8_t*) ((uintptr_t) i0 + input_increment);
80     i1 = (const int8_t*) ((uintptr_t) i1 + input_increment);
81     i2 = (const int8_t*) ((uintptr_t) i2 + input_increment);
82     i3 = (const int8_t*) ((uintptr_t) i3 + input_increment);
83     i4 = (const int8_t*) ((uintptr_t) i4 + input_increment);
84     i5 = (const int8_t*) ((uintptr_t) i5 + input_increment);
85     i6 = (const int8_t*) ((uintptr_t) i6 + input_increment);
86 
87     int32_t* b = buffer;
88     size_t c = channels;
89     for (; c != 0; c = doz(c, 8)) {
90       const __m128i vxi0x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i0));
91       i0 += 8;
92       const __m128i vxi1x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i1));
93       i1 += 8;
94       const __m128i vxi2x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i2));
95       i2 += 8;
96       const __m128i vxi3x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i3));
97       i3 += 8;
98       const __m128i vxi4x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i4));
99       i4 += 8;
100       const __m128i vxi5x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i5));
101       i5 += 8;
102       const __m128i vxi6x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i6));
103       i6 += 8;
104 
105 
106       __m128i vacc0x01234567 = _mm_add_epi16(vxi0x01234567, vxi1x01234567);
107       __m128i vacc1x01234567 = _mm_add_epi16(vxi2x01234567, vxi3x01234567);
108 
109       vacc0x01234567 = _mm_add_epi16(vacc0x01234567, vxi4x01234567);
110       vacc1x01234567 = _mm_add_epi16(vacc1x01234567, vxi5x01234567);
111       vacc0x01234567 = _mm_add_epi16(vacc0x01234567, vxi6x01234567);
112 
113       // Add up all accumulators to vacc0x01234567
114       vacc0x01234567 = _mm_add_epi16(vacc0x01234567, vacc1x01234567);
115 
116       const __m128i vacc0123 = _mm_add_epi32(_mm_cvtepi16_epi32(vacc0x01234567), _mm_load_si128((const __m128i*) (b + 0)));
117       const __m128i vacc4567 = _mm_add_epi32(_mm_unpackhi_epi16(vacc0x01234567, _mm_cmpgt_epi16(_mm_setzero_si128(), vacc0x01234567)), _mm_load_si128((const __m128i*) (b + 4)));
118 
119       _mm_store_si128((__m128i*) b, vacc0123);
120       _mm_store_si128((__m128i*) (b + 4), vacc4567);
121       b += 8;
122     }
123   }
124 
125   i0 = (const int8_t*) ((uintptr_t) i0 + input_increment);
126   i1 = (const int8_t*) ((uintptr_t) i1 + input_increment);
127   if XNN_UNPREDICTABLE(rows < 2) {
128     i1 = zero;
129   }
130   i2 = (const int8_t*) ((uintptr_t) i2 + input_increment);
131   if XNN_UNPREDICTABLE(rows <= 2) {
132     i2 = zero;
133   }
134   i3 = (const int8_t*) ((uintptr_t) i3 + input_increment);
135   if XNN_UNPREDICTABLE(rows < 4) {
136     i3 = zero;
137   }
138   i4 = (const int8_t*) ((uintptr_t) i4 + input_increment);
139   if XNN_UNPREDICTABLE(rows <= 4) {
140     i4 = zero;
141   }
142   i5 = (const int8_t*) ((uintptr_t) i5 + input_increment);
143   if XNN_UNPREDICTABLE(rows < 6) {
144     i5 = zero;
145   }
146   i6 = (const int8_t*) ((uintptr_t) i6 + input_increment);
147   if XNN_UNPREDICTABLE(rows <= 6) {
148     i6 = zero;
149   }
150 
151   const __m128i vmultiplier = _mm_load_si128((const __m128i*) params->sse2.multiplier);
152   const __m128i vrounding = _mm_load_si128((const __m128i*) params->sse2.rounding);
153   const __m128i vshift = _mm_loadl_epi64((const __m128i*) params->sse2.shift);
154   while (channels >= 8) {
155     const __m128i vxi0x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i0));
156     i0 += 8;
157     const __m128i vxi1x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i1));
158     i1 += 8;
159     const __m128i vxi2x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i2));
160     i2 += 8;
161     const __m128i vxi3x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i3));
162     i3 += 8;
163     const __m128i vxi4x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i4));
164     i4 += 8;
165     const __m128i vxi5x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i5));
166     i5 += 8;
167     const __m128i vxi6x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i6));
168     i6 += 8;
169 
170 
171     __m128i vacc0x01234567 = _mm_add_epi16(vxi0x01234567, vxi1x01234567);
172     __m128i vacc1x01234567 = _mm_add_epi16(vxi2x01234567, vxi3x01234567);
173 
174     vacc0x01234567 = _mm_add_epi16(vacc0x01234567, vxi4x01234567);
175     vacc1x01234567 = _mm_add_epi16(vacc1x01234567, vxi5x01234567);
176     vacc0x01234567 = _mm_add_epi16(vacc0x01234567, vxi6x01234567);
177 
178     // Add up all accumulators to vacc0x01234567
179     vacc0x01234567 = _mm_add_epi16(vacc0x01234567, vacc1x01234567);
180 
181     const __m128i vacc0123 = _mm_add_epi32(_mm_cvtepi16_epi32(vacc0x01234567), _mm_load_si128((const __m128i*) (buffer + 0)));
182     const __m128i vacc4567 = _mm_add_epi32(_mm_unpackhi_epi16(vacc0x01234567, _mm_cmpgt_epi16(_mm_setzero_si128(), vacc0x01234567)), _mm_load_si128((const __m128i*) (buffer + 4)));
183     buffer += 8;
184 
185     const __m128i vabsacc0123 = _mm_abs_epi32(vacc0123);
186     const __m128i vabsacc4567 = _mm_abs_epi32(vacc4567);
187 
188     const __m128i vabsacc13 = _mm_shuffle_epi32(vabsacc0123, _MM_SHUFFLE(3, 3, 1, 1));
189     const __m128i vabsacc57 = _mm_shuffle_epi32(vabsacc4567, _MM_SHUFFLE(3, 3, 1, 1));
190 
191     const __m128i vabsprod02 = _mm_mul_epu32(vabsacc0123, vmultiplier);
192     const __m128i vabsprod13 = _mm_mul_epu32(vabsacc13, vmultiplier);
193     const __m128i vabsprod46 = _mm_mul_epu32(vabsacc4567, vmultiplier);
194     const __m128i vabsprod57 = _mm_mul_epu32(vabsacc57, vmultiplier);
195 
196     const __m128i vabsout02 = _mm_srl_epi64(_mm_add_epi64(vabsprod02, vrounding), vshift);
197     const __m128i vabsout13 = _mm_srl_epi64(_mm_add_epi64(vabsprod13, vrounding), vshift);
198     const __m128i vabsout46 = _mm_srl_epi64(_mm_add_epi64(vabsprod46, vrounding), vshift);
199     const __m128i vabsout57 = _mm_srl_epi64(_mm_add_epi64(vabsprod57, vrounding), vshift);
200 
201     const __m128i vabsout0123 = _mm_blend_epi16(vabsout02, _mm_shuffle_epi32(vabsout13, _MM_SHUFFLE(2, 2, 0, 0)), 0xCC);
202     const __m128i vabsout4567 = _mm_blend_epi16(vabsout46, _mm_shuffle_epi32(vabsout57, _MM_SHUFFLE(2, 2, 0, 0)), 0xCC);
203 
204     const __m128i vout0123 = _mm_sign_epi32(vabsout0123, vacc0123);
205     const __m128i vout4567 = _mm_sign_epi32(vabsout4567, vacc4567);
206 
207     const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->sse2.output_zero_point);
208     __m128i vout01234567 = _mm_adds_epi16(_mm_packs_epi32(vout0123, vout4567), voutput_zero_point);
209 
210     const __m128i voutput_min = _mm_load_si128((const __m128i*) params->sse2.output_min);
211     const __m128i voutput_max = _mm_load_si128((const __m128i*) params->sse2.output_max);
212     vout01234567 = _mm_min_epi16(_mm_max_epi16(vout01234567, voutput_min), voutput_max);
213 
214     __m128i vout0123456701234567 = _mm_packs_epi16(vout01234567, vout01234567);
215 
216     _mm_storel_epi64((__m128i*) output, vout0123456701234567);
217     output += 8;
218 
219     channels -= 8;
220   }
221   if XNN_UNLIKELY(channels != 0) {
222     {
223       const __m128i vxi0x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i0));
224       i0 += 8;
225       const __m128i vxi1x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i1));
226       i1 += 8;
227       const __m128i vxi2x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i2));
228       i2 += 8;
229       const __m128i vxi3x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i3));
230       i3 += 8;
231       const __m128i vxi4x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i4));
232       i4 += 8;
233       const __m128i vxi5x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i5));
234       i5 += 8;
235       const __m128i vxi6x01234567 = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i6));
236       i6 += 8;
237 
238 
239       __m128i vacc0x01234567 = _mm_add_epi16(vxi0x01234567, vxi1x01234567);
240       __m128i vacc1x01234567 = _mm_add_epi16(vxi2x01234567, vxi3x01234567);
241 
242       vacc0x01234567 = _mm_add_epi16(vacc0x01234567, vxi4x01234567);
243       vacc1x01234567 = _mm_add_epi16(vacc1x01234567, vxi5x01234567);
244       vacc0x01234567 = _mm_add_epi16(vacc0x01234567, vxi6x01234567);
245 
246       // Add up all accumulators to vacc0x01234567
247       vacc0x01234567 = _mm_add_epi16(vacc0x01234567, vacc1x01234567);
248 
249       const __m128i vacc0123 = _mm_add_epi32(_mm_cvtepi16_epi32(vacc0x01234567), _mm_load_si128((const __m128i*) buffer));
250       const __m128i vacc4567 = _mm_add_epi32(_mm_unpackhi_epi16(vacc0x01234567, _mm_cmpgt_epi16(_mm_setzero_si128(), vacc0x01234567)), _mm_load_si128((const __m128i*) (buffer + 4)));
251       buffer += 8;
252 
253       const __m128i vabsacc0123 = _mm_abs_epi32(vacc0123);
254       const __m128i vabsacc4567 = _mm_abs_epi32(vacc4567);
255 
256       const __m128i vabsacc13 = _mm_shuffle_epi32(vabsacc0123, _MM_SHUFFLE(3, 3, 1, 1));
257       const __m128i vabsacc57 = _mm_shuffle_epi32(vabsacc4567, _MM_SHUFFLE(3, 3, 1, 1));
258 
259       const __m128i vabsprod02 = _mm_mul_epu32(vabsacc0123, vmultiplier);
260       const __m128i vabsprod13 = _mm_mul_epu32(vabsacc13, vmultiplier);
261       const __m128i vabsprod46 = _mm_mul_epu32(vabsacc4567, vmultiplier);
262       const __m128i vabsprod57 = _mm_mul_epu32(vabsacc57, vmultiplier);
263 
264       const __m128i vabsout02 = _mm_srl_epi64(_mm_add_epi64(vabsprod02, vrounding), vshift);
265       const __m128i vabsout13 = _mm_srl_epi64(_mm_add_epi64(vabsprod13, vrounding), vshift);
266       const __m128i vabsout46 = _mm_srl_epi64(_mm_add_epi64(vabsprod46, vrounding), vshift);
267       const __m128i vabsout57 = _mm_srl_epi64(_mm_add_epi64(vabsprod57, vrounding), vshift);
268 
269       const __m128i vabsout0123 = _mm_blend_epi16(vabsout02, _mm_shuffle_epi32(vabsout13, _MM_SHUFFLE(2, 2, 0, 0)), 0xCC);
270       const __m128i vabsout4567 = _mm_blend_epi16(vabsout46, _mm_shuffle_epi32(vabsout57, _MM_SHUFFLE(2, 2, 0, 0)), 0xCC);
271 
272       const __m128i vout0123 = _mm_sign_epi32(vabsout0123, vacc0123);
273       const __m128i vout4567 = _mm_sign_epi32(vabsout4567, vacc4567);
274 
275       const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->sse2.output_zero_point);
276       __m128i vout01234567 = _mm_adds_epi16(_mm_packs_epi32(vout0123, vout4567), voutput_zero_point);
277 
278       const __m128i voutput_min = _mm_load_si128((const __m128i*) params->sse2.output_min);
279       const __m128i voutput_max = _mm_load_si128((const __m128i*) params->sse2.output_max);
280       vout01234567 = _mm_min_epi16(_mm_max_epi16(vout01234567, voutput_min), voutput_max);
281 
282       __m128i vout0123456701234567 = _mm_packs_epi16(vout01234567, vout01234567);
283 
284       if (channels & 4) {
285         *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout0123456701234567);
286         vout0123456701234567 = _mm_srli_epi64(vout0123456701234567, 32);
287         output += 4;
288       }
289       if (channels & 2) {
290         *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout0123456701234567, 0);
291         vout0123456701234567 = _mm_srli_epi32(vout0123456701234567, 16);
292         output += 2;
293       }
294       if (channels & 1) {
295         *output = (int8_t) _mm_extract_epi8(vout0123456701234567, 0);
296       }
297     }
298   }
299 }
300