// Auto-generated file. Do not edit!
//   Template: src/f32-dwconv2d-chw/3x3p1-sse.c.in
//   Generator: tools/xngen
//
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <xmmintrin.h>

#include <xnnpack/dwconv.h>
#include <xnnpack/math.h>


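// Computes a 3x3 convolution with stride 1 over one 2D plane (one channel of
// a depthwise convolution) in CHW layout, with an implicit 1-pixel border of
// zero padding, so the output has the same height and width as the input.
// Each outer-loop iteration produces 3 output rows; each inner-loop iteration
// produces 4 output pixels per row. input_width is given in bytes; the zero
// pointer references a zero-filled row used for the padding rows.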
void xnn_f32_dwconv2d_chw_ukernel_3x3p1__sse_3x4(
    size_t input_height,
    size_t input_width,
    const float* input,
    const float* weights,
    const float* zero,
    float* output,
    uint32_t padding_top,
    const union xnn_f32_chw_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(input_height != 0);
  assert(input_width != 0);
  assert(input_width % sizeof(float) == 0);
  assert(padding_top == 1);

  const __m128 vmask = _mm_load_ps((const float*) params->sse.mask);
  const __m128 vmax = _mm_load_ps(params->sse.max);
  const __m128 vmin = _mm_load_ps(params->sse.min);

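  // weights[0] is the per-channel bias; weights[1..9] are the 3x3 kernel taps
  // in row-major order (k00, k01, k02, k10, ..., k22), each broadcast to all
  // four SIMD lanes.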
  const __m128 vbias = _mm_load1_ps(weights);
  const __m128 vk00 = _mm_load1_ps(weights + 1);
  const __m128 vk01 = _mm_load1_ps(weights + 2);
  const __m128 vk02 = _mm_load1_ps(weights + 3);
  const __m128 vk10 = _mm_load1_ps(weights + 4);
  const __m128 vk11 = _mm_load1_ps(weights + 5);
  const __m128 vk12 = _mm_load1_ps(weights + 6);
  const __m128 vk20 = _mm_load1_ps(weights + 7);
  const __m128 vk21 = _mm_load1_ps(weights + 8);
  const __m128 vk22 = _mm_load1_ps(weights + 9);

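  // Over one row block each input row pointer advances by input_width rounded
  // up to a whole 4-float tile; input_decrement rewinds that advance when the
  // row window slides down. i0 is the row above the first output row (the
  // zero row, since padding_top == 1); i1..i4 feed the 3 output rows of a
  // block, and o0..o2 are those output rows (input_width bytes wide each).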
  const size_t input_decrement = round_up_po2(input_width, 4 * sizeof(float));

  const float* i0 = zero;
  const float* i1 = input;
  const float* i2 = (const float*) ((uintptr_t) i1 + input_width);
  const float* i3 = (const float*) ((uintptr_t) i2 + input_width);
  const float* i4 = (const float*) ((uintptr_t) i3 + input_width);

  float* o0 = output;
  float* o1 = (float*) ((uintptr_t) o0 + input_width);
  float* o2 = (float*) ((uintptr_t) o1 + input_width);

  size_t output_height = input_height;
  do {
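    // If fewer than 3 output rows remain, redirect the out-of-range input
    // rows to the zero row (bottom padding) and alias the surplus output rows
    // onto a valid one; stores happen in the order o2, o1, o0, so the valid
    // row's results are written last and win.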
    if XNN_UNPREDICTABLE(output_height < 2) {
      i2 = zero;
      o1 = o0;
    }
    if XNN_UNPREDICTABLE(output_height < 3) {
      i3 = zero;
      o2 = o1;
    }
    if XNN_UNPREDICTABLE(output_height < 4) {
      i4 = zero;
    }

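    // Each viNx3012 register carries the previous 4-pixel tile of row N,
    // rotated so that its last pixel sits in lane 0; only that lane is used,
    // as the left neighbor of the current tile. At the start of a row it is
    // all zeros, which provides the implicit left-padding pixel.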
    // vi0x3012 = ( vi02, vi01, vi00, vi03 )
    __m128 vi0x3012 = _mm_setzero_ps();
    // vi1x3012 = ( vi12, vi11, vi10, vi13 )
    __m128 vi1x3012 = _mm_setzero_ps();
    // vi2x3012 = ( vi22, vi21, vi20, vi23 )
    __m128 vi2x3012 = _mm_setzero_ps();
    // vi3x3012 = ( vi32, vi31, vi30, vi33 )
    __m128 vi3x3012 = _mm_setzero_ps();
    // vi4x3012 = ( vi42, vi41, vi40, vi43 )
    __m128 vi4x3012 = _mm_setzero_ps();

    __m128 vi0x4567 = _mm_loadu_ps(i0);
    i0 += 4;
    __m128 vi1x4567 = _mm_loadu_ps(i1);
    i1 += 4;
    __m128 vi2x4567 = _mm_loadu_ps(i2);
    i2 += 4;
    __m128 vi3x4567 = _mm_loadu_ps(i3);
    i3 += 4;
    __m128 vi4x4567 = _mm_loadu_ps(i4);
    i4 += 4;

    size_t w = input_width;
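    // Main loop: one 4-pixel tile per row per iteration while more than a
    // full tile of the row remains; the final 1..4 pixels are handled after
    // the loop with masked inputs. w counts the remaining row bytes.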
    for (; w > 4 * sizeof(float); w -= 4 * sizeof(float)) {
      // vi0x89AB = ( vi0B, vi0A, vi09, vi08 )
      const __m128 vi0x89AB = _mm_loadu_ps(i0);
      i0 += 4;
      // vi1x89AB = ( vi1B, vi1A, vi19, vi18 )
      const __m128 vi1x89AB = _mm_loadu_ps(i1);
      i1 += 4;
      // vi2x89AB = ( vi2B, vi2A, vi29, vi28 )
      const __m128 vi2x89AB = _mm_loadu_ps(i2);
      i2 += 4;
      // vi3x89AB = ( vi3B, vi3A, vi39, vi38 )
      const __m128 vi3x89AB = _mm_loadu_ps(i3);
      i3 += 4;
      // vi4x89AB = ( vi4B, vi4A, vi49, vi48 )
      const __m128 vi4x89AB = _mm_loadu_ps(i4);
      i4 += 4;

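      // Rotate the current tile one lane to the right (x7456), then splice in
      // the previous tile's last pixel from x3012 lane 0 (_mm_move_ss) to
      // form x3456, the window shifted one pixel to the left, which feeds
      // kernel column 0 (vk00, vk10, vk20).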
      // vi0x7456 = ( vi06, vi05, vi04, vi07 )
      const __m128 vi0x7456 = _mm_shuffle_ps(vi0x4567, vi0x4567, _MM_SHUFFLE(2, 1, 0, 3));
      // vi1x7456 = ( vi16, vi15, vi14, vi17 )
      const __m128 vi1x7456 = _mm_shuffle_ps(vi1x4567, vi1x4567, _MM_SHUFFLE(2, 1, 0, 3));
      // vi2x7456 = ( vi26, vi25, vi24, vi27 )
      const __m128 vi2x7456 = _mm_shuffle_ps(vi2x4567, vi2x4567, _MM_SHUFFLE(2, 1, 0, 3));
      // vi3x7456 = ( vi36, vi35, vi34, vi37 )
      const __m128 vi3x7456 = _mm_shuffle_ps(vi3x4567, vi3x4567, _MM_SHUFFLE(2, 1, 0, 3));
      // vi4x7456 = ( vi46, vi45, vi44, vi47 )
      const __m128 vi4x7456 = _mm_shuffle_ps(vi4x4567, vi4x4567, _MM_SHUFFLE(2, 1, 0, 3));

      __m128 vo0p0 = _mm_add_ps(vbias, _mm_mul_ps(vi0x4567, vk01));
      __m128 vo1p0 = _mm_add_ps(vbias, _mm_mul_ps(vi1x4567, vk01));
      __m128 vo2p0 = _mm_add_ps(vbias, _mm_mul_ps(vi2x4567, vk01));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x4567, vk11));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x4567, vk11));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi3x4567, vk11));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x4567, vk21));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x4567, vk21));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi4x4567, vk21));

      // vi0x3456 = ( vi06, vi05, vi04, vi03 )
      const __m128 vi0x3456 = _mm_move_ss(vi0x7456, vi0x3012);
      // vi1x3456 = ( vi16, vi15, vi14, vi13 )
      const __m128 vi1x3456 = _mm_move_ss(vi1x7456, vi1x3012);
      // vi2x3456 = ( vi26, vi25, vi24, vi23 )
      const __m128 vi2x3456 = _mm_move_ss(vi2x7456, vi2x3012);
      // vi3x3456 = ( vi36, vi35, vi34, vi33 )
      const __m128 vi3x3456 = _mm_move_ss(vi3x7456, vi3x3012);
      // vi4x3456 = ( vi46, vi45, vi44, vi43 )
      const __m128 vi4x3456 = _mm_move_ss(vi4x7456, vi4x3012);

      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi0x3456, vk00));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi1x3456, vk00));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi2x3456, vk00));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x3456, vk10));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x3456, vk10));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi3x3456, vk10));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x3456, vk20));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x3456, vk20));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi4x3456, vk20));

      vi0x3012 = vi0x7456;
      vi1x3012 = vi1x7456;
      vi2x3012 = vi2x7456;
      vi3x3012 = vi3x7456;
      vi4x3012 = vi4x7456;

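      // Splice the first pixel of the next tile into lane 0 (x8567), then
      // rotate one lane to the left to form x5678, the window shifted one
      // pixel to the right, which feeds kernel column 2 (vk02, vk12, vk22).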
      // vi0x8567 = ( vi07, vi06, vi05, vi08 )
      const __m128 vi0x8567 = _mm_move_ss(vi0x4567, vi0x89AB);
      // vi1x8567 = ( vi17, vi16, vi15, vi18 )
      const __m128 vi1x8567 = _mm_move_ss(vi1x4567, vi1x89AB);
      // vi2x8567 = ( vi27, vi26, vi25, vi28 )
      const __m128 vi2x8567 = _mm_move_ss(vi2x4567, vi2x89AB);
      // vi3x8567 = ( vi37, vi36, vi35, vi38 )
      const __m128 vi3x8567 = _mm_move_ss(vi3x4567, vi3x89AB);
      // vi4x8567 = ( vi47, vi46, vi45, vi48 )
      const __m128 vi4x8567 = _mm_move_ss(vi4x4567, vi4x89AB);

      // vi0x5678 = ( vi08, vi07, vi06, vi05 )
      const __m128 vi0x5678 = _mm_shuffle_ps(vi0x8567, vi0x8567, _MM_SHUFFLE(0, 3, 2, 1));
      // vi1x5678 = ( vi18, vi17, vi16, vi15 )
      const __m128 vi1x5678 = _mm_shuffle_ps(vi1x8567, vi1x8567, _MM_SHUFFLE(0, 3, 2, 1));
      // vi2x5678 = ( vi28, vi27, vi26, vi25 )
      const __m128 vi2x5678 = _mm_shuffle_ps(vi2x8567, vi2x8567, _MM_SHUFFLE(0, 3, 2, 1));
      // vi3x5678 = ( vi38, vi37, vi36, vi35 )
      const __m128 vi3x5678 = _mm_shuffle_ps(vi3x8567, vi3x8567, _MM_SHUFFLE(0, 3, 2, 1));
      // vi4x5678 = ( vi48, vi47, vi46, vi45 )
      const __m128 vi4x5678 = _mm_shuffle_ps(vi4x8567, vi4x8567, _MM_SHUFFLE(0, 3, 2, 1));

      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi0x5678, vk02));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi1x5678, vk02));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi2x5678, vk02));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x5678, vk12));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x5678, vk12));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi3x5678, vk12));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x5678, vk22));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x5678, vk22));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi4x5678, vk22));

      vi0x4567 = vi0x89AB;
      vi1x4567 = vi1x89AB;
      vi2x4567 = vi2x89AB;
      vi3x4567 = vi3x89AB;
      vi4x4567 = vi4x89AB;


      __m128 vo0 = _mm_max_ps(vo0p0, vmin);
      __m128 vo1 = _mm_max_ps(vo1p0, vmin);
      __m128 vo2 = _mm_max_ps(vo2p0, vmin);

      vo0 = _mm_min_ps(vo0, vmax);
      vo1 = _mm_min_ps(vo1, vmax);
      vo2 = _mm_min_ps(vo2, vmax);

      _mm_storeu_ps(o2, vo2);
      o2 += 4;
      _mm_storeu_ps(o1, vo1);
      o1 += 4;
      _mm_storeu_ps(o0, vo0);
      o0 += 4;
    }
    // Always process the last block of 1..4 pixels.
    assert(w >= 1 * sizeof(float));
    assert(w <= 4 * sizeof(float));
    {
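      // vmask keeps only the lanes that correspond to valid pixels in this
      // final tile; the out-of-range lanes are zeroed so they contribute
      // nothing to the convolution.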
      vi0x4567 = _mm_and_ps(vmask, vi0x4567);
      vi1x4567 = _mm_and_ps(vmask, vi1x4567);
      vi2x4567 = _mm_and_ps(vmask, vi2x4567);
      vi3x4567 = _mm_and_ps(vmask, vi3x4567);
      vi4x4567 = _mm_and_ps(vmask, vi4x4567);

      // vi0x7456 = ( vi06, vi05, vi04, vi07 )
      const __m128 vi0x7456 = _mm_shuffle_ps(vi0x4567, vi0x4567, _MM_SHUFFLE(2, 1, 0, 3));
      // vi1x7456 = ( vi16, vi15, vi14, vi17 )
      const __m128 vi1x7456 = _mm_shuffle_ps(vi1x4567, vi1x4567, _MM_SHUFFLE(2, 1, 0, 3));
      // vi2x7456 = ( vi26, vi25, vi24, vi27 )
      const __m128 vi2x7456 = _mm_shuffle_ps(vi2x4567, vi2x4567, _MM_SHUFFLE(2, 1, 0, 3));
      // vi3x7456 = ( vi36, vi35, vi34, vi37 )
      const __m128 vi3x7456 = _mm_shuffle_ps(vi3x4567, vi3x4567, _MM_SHUFFLE(2, 1, 0, 3));
      // vi4x7456 = ( vi46, vi45, vi44, vi47 )
      const __m128 vi4x7456 = _mm_shuffle_ps(vi4x4567, vi4x4567, _MM_SHUFFLE(2, 1, 0, 3));

      __m128 vo0p0 = _mm_add_ps(vbias, _mm_mul_ps(vi0x4567, vk01));
      __m128 vo1p0 = _mm_add_ps(vbias, _mm_mul_ps(vi1x4567, vk01));
      __m128 vo2p0 = _mm_add_ps(vbias, _mm_mul_ps(vi2x4567, vk01));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x4567, vk11));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x4567, vk11));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi3x4567, vk11));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x4567, vk21));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x4567, vk21));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi4x4567, vk21));

      // vi0x3456 = ( vi06, vi05, vi04, vi03 )
      const __m128 vi0x3456 = _mm_move_ss(vi0x7456, vi0x3012);
      // vi1x3456 = ( vi16, vi15, vi14, vi13 )
      const __m128 vi1x3456 = _mm_move_ss(vi1x7456, vi1x3012);
      // vi2x3456 = ( vi26, vi25, vi24, vi23 )
      const __m128 vi2x3456 = _mm_move_ss(vi2x7456, vi2x3012);
      // vi3x3456 = ( vi36, vi35, vi34, vi33 )
      const __m128 vi3x3456 = _mm_move_ss(vi3x7456, vi3x3012);
      // vi4x3456 = ( vi46, vi45, vi44, vi43 )
      const __m128 vi4x3456 = _mm_move_ss(vi4x7456, vi4x3012);

      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi0x3456, vk00));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi1x3456, vk00));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi2x3456, vk00));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x3456, vk10));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x3456, vk10));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi3x3456, vk10));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x3456, vk20));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x3456, vk20));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi4x3456, vk20));

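      // This is the last tile of the row, so zero is spliced into lane 0 in
      // place of a pixel from a next tile; after the rotation, the lane that
      // would hold the pixel to the right of the last valid one is zero,
      // which supplies the implicit right padding.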
      const __m128 vzero = _mm_setzero_ps();
      // vi0x8567 = ( vi07, vi06, vi05, 0.0 )
      const __m128 vi0x8567 = _mm_move_ss(vi0x4567, vzero);
      // vi1x8567 = ( vi17, vi16, vi15, 0.0 )
      const __m128 vi1x8567 = _mm_move_ss(vi1x4567, vzero);
      // vi2x8567 = ( vi27, vi26, vi25, 0.0 )
      const __m128 vi2x8567 = _mm_move_ss(vi2x4567, vzero);
      // vi3x8567 = ( vi37, vi36, vi35, 0.0 )
      const __m128 vi3x8567 = _mm_move_ss(vi3x4567, vzero);
      // vi4x8567 = ( vi47, vi46, vi45, 0.0 )
      const __m128 vi4x8567 = _mm_move_ss(vi4x4567, vzero);

      // vi0x5678 = ( vi08, vi07, vi06, vi05 )
      const __m128 vi0x5678 = _mm_shuffle_ps(vi0x8567, vi0x8567, _MM_SHUFFLE(0, 3, 2, 1));
      // vi1x5678 = ( vi18, vi17, vi16, vi15 )
      const __m128 vi1x5678 = _mm_shuffle_ps(vi1x8567, vi1x8567, _MM_SHUFFLE(0, 3, 2, 1));
      // vi2x5678 = ( vi28, vi27, vi26, vi25 )
      const __m128 vi2x5678 = _mm_shuffle_ps(vi2x8567, vi2x8567, _MM_SHUFFLE(0, 3, 2, 1));
      // vi3x5678 = ( vi38, vi37, vi36, vi35 )
      const __m128 vi3x5678 = _mm_shuffle_ps(vi3x8567, vi3x8567, _MM_SHUFFLE(0, 3, 2, 1));
      // vi4x5678 = ( vi48, vi47, vi46, vi45 )
      const __m128 vi4x5678 = _mm_shuffle_ps(vi4x8567, vi4x8567, _MM_SHUFFLE(0, 3, 2, 1));

      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi0x5678, vk02));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi1x5678, vk02));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi2x5678, vk02));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x5678, vk12));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x5678, vk12));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi3x5678, vk12));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x5678, vk22));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x5678, vk22));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi4x5678, vk22));


      __m128 vo0 = _mm_max_ps(vo0p0, vmin);
      __m128 vo1 = _mm_max_ps(vo1p0, vmin);
      __m128 vo2 = _mm_max_ps(vo2p0, vmin);

      vo0 = _mm_min_ps(vo0, vmax);
      vo1 = _mm_min_ps(vo1, vmax);
      vo2 = _mm_min_ps(vo2, vmax);

      if XNN_LIKELY(w == 4 * sizeof(float)) {
        _mm_storeu_ps(o2, vo2);
        o2 += 4;
        _mm_storeu_ps(o1, vo1);
        o1 += 4;
        _mm_storeu_ps(o0, vo0);
        o0 += 4;
      } else {
        if (w & (2 * sizeof(float))) {
          _mm_storel_pi((__m64*) o2, vo2);
          o2 += 2;
          _mm_storel_pi((__m64*) o1, vo1);
          o1 += 2;
          _mm_storel_pi((__m64*) o0, vo0);
          o0 += 2;

          vo0 = _mm_movehl_ps(vo0, vo0);
          vo1 = _mm_movehl_ps(vo1, vo1);
          vo2 = _mm_movehl_ps(vo2, vo2);
        }
        if (w & (1 * sizeof(float))) {
          _mm_store_ss(o2, vo2);
          o2 += 1;
          _mm_store_ss(o1, vo1);
          o1 += 1;
          _mm_store_ss(o0, vo0);
          o0 += 1;
        }
      }
    }

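    // Rewind the row pointers (each advanced by the tile-rounded row width)
    // and slide the five-row input window down by 3 rows: the next block's
    // top neighbor row is the row i3 pointed to at the start of this block.
    // o2 has advanced past its output row, so it now points at the first of
    // the next block's 3 output rows.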
    i0 = (const float*) ((uintptr_t) i3 - input_decrement);
    i1 = (const float*) ((uintptr_t) i4 - input_decrement);
    i2 = (const float*) ((uintptr_t) i1 + input_width);
    i3 = (const float*) ((uintptr_t) i2 + input_width);
    i4 = (const float*) ((uintptr_t) i3 + input_width);

    o0 = o2;
    o1 = (float*) ((uintptr_t) o0 + input_width);
    o2 = (float*) ((uintptr_t) o1 + input_width);

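    // doz() is a saturating difference (max(a - b, 0)), so the remaining row
    // count never wraps below zero.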
    output_height = doz(output_height, 3);
  } while (output_height != 0);
}