// Auto-generated file. Do not edit!
//   Template: src/f32-dwconv2d-chw/3x3p1-ssse3.c.in
//   Generator: tools/xngen
//
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <tmmintrin.h>

#include <xnnpack/dwconv.h>
#include <xnnpack/math.h>


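// Depthwise 3x3 convolution (stride 1, padding 1) over f32 data in CHW layout.
// Each iteration of the outer loop produces up to 4 output rows, and the inner
// loop produces 4 output pixels per row using SSSE3 128-bit vectors.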
void xnn_f32_dwconv2d_chw_ukernel_3x3p1__ssse3_4x4(
    size_t input_height,
    size_t input_width,
    const float* input,
    const float* weights,
    const float* zero,
    float* output,
    uint32_t padding_top,
    const union xnn_f32_chw_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(input_height != 0);
  assert(input_width != 0);
  assert(input_width % sizeof(float) == 0);
  assert(padding_top == 1);

  const __m128 vmask = _mm_load_ps((const float*) params->sse.mask);
  const __m128 vmax = _mm_load_ps(params->sse.max);
  const __m128 vmin = _mm_load_ps(params->sse.min);

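  // Weights are packed as the bias followed by the 9 filter taps in row-major order.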
  const __m128 vbias = _mm_load1_ps(weights);
  const __m128 vk00 = _mm_load1_ps(weights + 1);
  const __m128 vk01 = _mm_load1_ps(weights + 2);
  const __m128 vk02 = _mm_load1_ps(weights + 3);
  const __m128 vk10 = _mm_load1_ps(weights + 4);
  const __m128 vk11 = _mm_load1_ps(weights + 5);
  const __m128 vk12 = _mm_load1_ps(weights + 6);
  const __m128 vk20 = _mm_load1_ps(weights + 7);
  const __m128 vk21 = _mm_load1_ps(weights + 8);
  const __m128 vk22 = _mm_load1_ps(weights + 9);

  const size_t input_decrement = round_up_po2(input_width, 4 * sizeof(float));

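  // i0 reads the explicit zero row for the top padding; i1..i5 read consecutive
  // input rows (input_width is given in bytes).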
  const float* i0 = zero;
  const float* i1 = input;
  const float* i2 = (const float*) ((uintptr_t) i1 + input_width);
  const float* i3 = (const float*) ((uintptr_t) i2 + input_width);
  const float* i4 = (const float*) ((uintptr_t) i3 + input_width);
  const float* i5 = (const float*) ((uintptr_t) i4 + input_width);

  float* o0 = output;
  float* o1 = (float*) ((uintptr_t) o0 + input_width);
  float* o2 = (float*) ((uintptr_t) o1 + input_width);
  float* o3 = (float*) ((uintptr_t) o2 + input_width);

  size_t output_height = input_height;
  do {
    if XNN_UNPREDICTABLE(output_height < 2) {
      i2 = zero;
      o1 = o0;
    }
    if XNN_UNPREDICTABLE(output_height < 3) {
      i3 = zero;
      o2 = o1;
    }
    if XNN_UNPREDICTABLE(output_height < 4) {
      i4 = zero;
      o3 = o2;
    }
    if XNN_UNPREDICTABLE(output_height < 5) {
      i5 = zero;
    }

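    // The vi*x0123 vectors hold the four pixels to the left of the current window;
    // they start as zeros to implement the left padding.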
    __m128 vi0x0123 = _mm_setzero_ps();
    __m128 vi1x0123 = _mm_setzero_ps();
    __m128 vi2x0123 = _mm_setzero_ps();
    __m128 vi3x0123 = _mm_setzero_ps();
    __m128 vi4x0123 = _mm_setzero_ps();
    __m128 vi5x0123 = _mm_setzero_ps();

    __m128 vi0x4567 = _mm_loadu_ps(i0);
    i0 += 4;
    __m128 vi1x4567 = _mm_loadu_ps(i1);
    i1 += 4;
    __m128 vi2x4567 = _mm_loadu_ps(i2);
    i2 += 4;
    __m128 vi3x4567 = _mm_loadu_ps(i3);
    i3 += 4;
    __m128 vi4x4567 = _mm_loadu_ps(i4);
    i4 += 4;
    __m128 vi5x4567 = _mm_loadu_ps(i5);
    i5 += 4;

    size_t w = input_width;
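    // Main loop: produce 4 output pixels per row per iteration while more than
    // 4 input pixels remain in the row.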
    for (; w > 4 * sizeof(float); w -= 4 * sizeof(float)) {
      const __m128 vi0x89AB = _mm_loadu_ps(i0);
      i0 += 4;
      const __m128 vi1x89AB = _mm_loadu_ps(i1);
      i1 += 4;
      const __m128 vi2x89AB = _mm_loadu_ps(i2);
      i2 += 4;
      const __m128 vi3x89AB = _mm_loadu_ps(i3);
      i3 += 4;
      const __m128 vi4x89AB = _mm_loadu_ps(i4);
      i4 += 4;
      const __m128 vi5x89AB = _mm_loadu_ps(i5);
      i5 += 4;

      __m128 vo0p0 = _mm_add_ps(vbias, _mm_mul_ps(vi0x4567, vk01));
      __m128 vo1p0 = _mm_add_ps(vbias, _mm_mul_ps(vi1x4567, vk01));
      __m128 vo2p0 = _mm_add_ps(vbias, _mm_mul_ps(vi2x4567, vk01));
      __m128 vo3p0 = _mm_add_ps(vbias, _mm_mul_ps(vi3x4567, vk01));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x4567, vk11));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x4567, vk11));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi3x4567, vk11));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi4x4567, vk11));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x4567, vk21));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x4567, vk21));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi4x4567, vk21));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi5x4567, vk21));

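      // SSSE3 PALIGNR stitches the previous and current vectors together to form
      // the window shifted one pixel to the left (columns 3..6).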
      const __m128 vi0x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi0x4567), _mm_castps_si128(vi0x0123), 12));
      const __m128 vi1x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi1x4567), _mm_castps_si128(vi1x0123), 12));
      const __m128 vi2x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi2x4567), _mm_castps_si128(vi2x0123), 12));
      const __m128 vi3x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi3x4567), _mm_castps_si128(vi3x0123), 12));
      const __m128 vi4x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi4x4567), _mm_castps_si128(vi4x0123), 12));
      const __m128 vi5x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi5x4567), _mm_castps_si128(vi5x0123), 12));

      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi0x3456, vk00));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi1x3456, vk00));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi2x3456, vk00));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi3x3456, vk00));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x3456, vk10));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x3456, vk10));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi3x3456, vk10));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi4x3456, vk10));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x3456, vk20));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x3456, vk20));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi4x3456, vk20));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi5x3456, vk20));

      vi0x0123 = vi0x4567;
      vi1x0123 = vi1x4567;
      vi2x0123 = vi2x4567;
      vi3x0123 = vi3x4567;
      vi4x0123 = vi4x4567;
      vi5x0123 = vi5x4567;

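      // PALIGNR with the next vector forms the window shifted one pixel to the
      // right (columns 5..8).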
      const __m128 vi0x5678 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi0x89AB), _mm_castps_si128(vi0x4567), 4));
      const __m128 vi1x5678 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi1x89AB), _mm_castps_si128(vi1x4567), 4));
      const __m128 vi2x5678 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi2x89AB), _mm_castps_si128(vi2x4567), 4));
      const __m128 vi3x5678 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi3x89AB), _mm_castps_si128(vi3x4567), 4));
      const __m128 vi4x5678 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi4x89AB), _mm_castps_si128(vi4x4567), 4));
      const __m128 vi5x5678 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi5x89AB), _mm_castps_si128(vi5x4567), 4));

      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi0x5678, vk02));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi1x5678, vk02));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi2x5678, vk02));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi3x5678, vk02));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x5678, vk12));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x5678, vk12));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi3x5678, vk12));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi4x5678, vk12));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x5678, vk22));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x5678, vk22));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi4x5678, vk22));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi5x5678, vk22));

      vi0x4567 = vi0x89AB;
      vi1x4567 = vi1x89AB;
      vi2x4567 = vi2x89AB;
      vi3x4567 = vi3x89AB;
      vi4x4567 = vi4x89AB;
      vi5x4567 = vi5x89AB;


      __m128 vo0 = _mm_max_ps(vo0p0, vmin);
      __m128 vo1 = _mm_max_ps(vo1p0, vmin);
      __m128 vo2 = _mm_max_ps(vo2p0, vmin);
      __m128 vo3 = _mm_max_ps(vo3p0, vmin);

      vo0 = _mm_min_ps(vo0, vmax);
      vo1 = _mm_min_ps(vo1, vmax);
      vo2 = _mm_min_ps(vo2, vmax);
      vo3 = _mm_min_ps(vo3, vmax);

      _mm_storeu_ps(o3, vo3);
      o3 += 4;
      _mm_storeu_ps(o2, vo2);
      o2 += 4;
      _mm_storeu_ps(o1, vo1);
      o1 += 4;
      _mm_storeu_ps(o0, vo0);
      o0 += 4;
    }
    // Always process the last block of 1..4 pixels.
    assert(w >= 1 * sizeof(float));
    assert(w <= 4 * sizeof(float));
    {
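      // Mask off the lanes past the end of the row so that out-of-bounds input
      // does not contribute to the results.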
      vi0x4567 = _mm_and_ps(vmask, vi0x4567);
      vi1x4567 = _mm_and_ps(vmask, vi1x4567);
      vi2x4567 = _mm_and_ps(vmask, vi2x4567);
      vi3x4567 = _mm_and_ps(vmask, vi3x4567);
      vi4x4567 = _mm_and_ps(vmask, vi4x4567);
      vi5x4567 = _mm_and_ps(vmask, vi5x4567);

      __m128 vo0p0 = _mm_add_ps(vbias, _mm_mul_ps(vi0x4567, vk01));
      __m128 vo1p0 = _mm_add_ps(vbias, _mm_mul_ps(vi1x4567, vk01));
      __m128 vo2p0 = _mm_add_ps(vbias, _mm_mul_ps(vi2x4567, vk01));
      __m128 vo3p0 = _mm_add_ps(vbias, _mm_mul_ps(vi3x4567, vk01));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x4567, vk11));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x4567, vk11));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi3x4567, vk11));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi4x4567, vk11));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x4567, vk21));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x4567, vk21));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi4x4567, vk21));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi5x4567, vk21));

      const __m128 vi0x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi0x4567), _mm_castps_si128(vi0x0123), 12));
      const __m128 vi1x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi1x4567), _mm_castps_si128(vi1x0123), 12));
      const __m128 vi2x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi2x4567), _mm_castps_si128(vi2x0123), 12));
      const __m128 vi3x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi3x4567), _mm_castps_si128(vi3x0123), 12));
      const __m128 vi4x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi4x4567), _mm_castps_si128(vi4x0123), 12));
      const __m128 vi5x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi5x4567), _mm_castps_si128(vi5x0123), 12));

      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi0x3456, vk00));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi1x3456, vk00));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi2x3456, vk00));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi3x3456, vk00));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x3456, vk10));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x3456, vk10));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi3x3456, vk10));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi4x3456, vk10));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x3456, vk20));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x3456, vk20));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi4x3456, vk20));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi5x3456, vk20));

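      // In the final block the window shifted to the right is padded with zeros
      // (right padding) instead of loading past the end of the row.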
      const __m128i vzero = _mm_setzero_si128();
      const __m128 vi0x5678 = _mm_castsi128_ps(_mm_alignr_epi8(vzero, _mm_castps_si128(vi0x4567), 4));
      const __m128 vi1x5678 = _mm_castsi128_ps(_mm_alignr_epi8(vzero, _mm_castps_si128(vi1x4567), 4));
      const __m128 vi2x5678 = _mm_castsi128_ps(_mm_alignr_epi8(vzero, _mm_castps_si128(vi2x4567), 4));
      const __m128 vi3x5678 = _mm_castsi128_ps(_mm_alignr_epi8(vzero, _mm_castps_si128(vi3x4567), 4));
      const __m128 vi4x5678 = _mm_castsi128_ps(_mm_alignr_epi8(vzero, _mm_castps_si128(vi4x4567), 4));
      const __m128 vi5x5678 = _mm_castsi128_ps(_mm_alignr_epi8(vzero, _mm_castps_si128(vi5x4567), 4));

      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi0x5678, vk02));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi1x5678, vk02));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi2x5678, vk02));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi3x5678, vk02));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x5678, vk12));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x5678, vk12));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi3x5678, vk12));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi4x5678, vk12));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x5678, vk22));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x5678, vk22));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi4x5678, vk22));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi5x5678, vk22));


      __m128 vo0 = _mm_max_ps(vo0p0, vmin);
      __m128 vo1 = _mm_max_ps(vo1p0, vmin);
      __m128 vo2 = _mm_max_ps(vo2p0, vmin);
      __m128 vo3 = _mm_max_ps(vo3p0, vmin);

      vo0 = _mm_min_ps(vo0, vmax);
      vo1 = _mm_min_ps(vo1, vmax);
      vo2 = _mm_min_ps(vo2, vmax);
      vo3 = _mm_min_ps(vo3, vmax);

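      // Store 4 pixels per row when a full block remains; otherwise store the
      // 1..3 remaining pixels with 64-bit and 32-bit tail stores.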
      if XNN_LIKELY(w == 4 * sizeof(float)) {
        _mm_storeu_ps(o3, vo3);
        o3 += 4;
        _mm_storeu_ps(o2, vo2);
        o2 += 4;
        _mm_storeu_ps(o1, vo1);
        o1 += 4;
        _mm_storeu_ps(o0, vo0);
        o0 += 4;
      } else {
        if (w & (2 * sizeof(float))) {
          _mm_storel_pi((__m64*) o3, vo3);
          o3 += 2;
          _mm_storel_pi((__m64*) o2, vo2);
          o2 += 2;
          _mm_storel_pi((__m64*) o1, vo1);
          o1 += 2;
          _mm_storel_pi((__m64*) o0, vo0);
          o0 += 2;

          vo0 = _mm_movehl_ps(vo0, vo0);
          vo1 = _mm_movehl_ps(vo1, vo1);
          vo2 = _mm_movehl_ps(vo2, vo2);
          vo3 = _mm_movehl_ps(vo3, vo3);
        }
        if (w & (1 * sizeof(float))) {
          _mm_store_ss(o3, vo3);
          o3 += 1;
          _mm_store_ss(o2, vo2);
          o2 += 1;
          _mm_store_ss(o1, vo1);
          o1 += 1;
          _mm_store_ss(o0, vo0);
          o0 += 1;
        }
      }
    }

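    // Set up the input pointers for the next group of 4 output rows: subtracting
    // input_decrement rewinds i4 and i5 to the start of the rows they just read,
    // which become the new i0 and i1.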
    i0 = (const float*) ((uintptr_t) i4 - input_decrement);
    i1 = (const float*) ((uintptr_t) i5 - input_decrement);
    i2 = (const float*) ((uintptr_t) i1 + input_width);
    i3 = (const float*) ((uintptr_t) i2 + input_width);
    i4 = (const float*) ((uintptr_t) i3 + input_width);
    i5 = (const float*) ((uintptr_t) i4 + input_width);

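    // Advance the output pointers past the 4 rows just written.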
    o0 = o3;
    o1 = (float*) ((uintptr_t) o0 + input_width);
    o2 = (float*) ((uintptr_t) o1 + input_width);
    o3 = (float*) ((uintptr_t) o2 + input_width);

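    // doz() is a saturating ("difference or zero") subtraction: the final group
    // may produce fewer than 4 rows.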
    output_height = doz(output_height, 4);
  } while (output_height != 0);
}