// Auto-generated file. Do not edit!
// Template: src/f32-dwconv2d-chw/3x3p1-ssse3.c.in
// Generator: tools/xngen
//
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <tmmintrin.h>

#include <xnnpack/dwconv.h>
#include <xnnpack/math.h>


void xnn_f32_dwconv2d_chw_ukernel_3x3p1__ssse3_4x4(
    size_t input_height,
    size_t input_width,
    const float* input,
    const float* weights,
    const float* zero,
    float* output,
    uint32_t padding_top,
    const union xnn_f32_chw_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(input_height != 0);
  assert(input_width != 0);
  assert(input_width % sizeof(float) == 0);
  assert(padding_top == 1);

  const __m128 vmask = _mm_load_ps((const float*) params->sse.mask);
  const __m128 vmax = _mm_load_ps(params->sse.max);
  const __m128 vmin = _mm_load_ps(params->sse.min);

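  // Weights layout: the first element is the bias, followed by the 3x3 kernel
  // coefficients in row-major order (k00..k22); each is broadcast to all 4 lanes.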
  const __m128 vbias = _mm_load1_ps(weights);
  const __m128 vk00 = _mm_load1_ps(weights + 1);
  const __m128 vk01 = _mm_load1_ps(weights + 2);
  const __m128 vk02 = _mm_load1_ps(weights + 3);
  const __m128 vk10 = _mm_load1_ps(weights + 4);
  const __m128 vk11 = _mm_load1_ps(weights + 5);
  const __m128 vk12 = _mm_load1_ps(weights + 6);
  const __m128 vk20 = _mm_load1_ps(weights + 7);
  const __m128 vk21 = _mm_load1_ps(weights + 8);
  const __m128 vk22 = _mm_load1_ps(weights + 9);

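  // Over one row, each input pointer advances by input_width rounded up to a
  // 4-pixel tile; subtracting input_decrement rewinds it to the start of that row.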
  const size_t input_decrement = round_up_po2(input_width, 4 * sizeof(float));

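  // i0 starts at the zero row to implement the single row of top padding; i1..i5
  // cover the 6 consecutive input rows needed to compute 4 output rows of a 3x3 window.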
  const float* i0 = zero;
  const float* i1 = input;
  const float* i2 = (const float*) ((uintptr_t) i1 + input_width);
  const float* i3 = (const float*) ((uintptr_t) i2 + input_width);
  const float* i4 = (const float*) ((uintptr_t) i3 + input_width);
  const float* i5 = (const float*) ((uintptr_t) i4 + input_width);

  float* o0 = output;
  float* o1 = (float*) ((uintptr_t) o0 + input_width);
  float* o2 = (float*) ((uintptr_t) o1 + input_width);
  float* o3 = (float*) ((uintptr_t) o2 + input_width);

  size_t output_height = input_height;
  do {
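    // With fewer than 4 output rows left, point the unused input rows at the zero
    // row and alias the unused output pointers to valid rows; stores are issued
    // o3 first and o0 last, so the real rows are written last and win.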
    if XNN_UNPREDICTABLE(output_height < 2) {
      i2 = zero;
      o1 = o0;
    }
    if XNN_UNPREDICTABLE(output_height < 3) {
      i3 = zero;
      o2 = o1;
    }
    if XNN_UNPREDICTABLE(output_height < 4) {
      i4 = zero;
      o3 = o2;
    }
    if XNN_UNPREDICTABLE(output_height < 5) {
      i5 = zero;
    }

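    // vi*x0123 hold the previous 4 pixels of each row; starting them at zero
    // implements the single column of left padding.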
    __m128 vi0x0123 = _mm_setzero_ps();
    __m128 vi1x0123 = _mm_setzero_ps();
    __m128 vi2x0123 = _mm_setzero_ps();
    __m128 vi3x0123 = _mm_setzero_ps();
    __m128 vi4x0123 = _mm_setzero_ps();
    __m128 vi5x0123 = _mm_setzero_ps();

    __m128 vi0x4567 = _mm_loadu_ps(i0);
    i0 += 4;
    __m128 vi1x4567 = _mm_loadu_ps(i1);
    i1 += 4;
    __m128 vi2x4567 = _mm_loadu_ps(i2);
    i2 += 4;
    __m128 vi3x4567 = _mm_loadu_ps(i3);
    i3 += 4;
    __m128 vi4x4567 = _mm_loadu_ps(i4);
    i4 += 4;
    __m128 vi5x4567 = _mm_loadu_ps(i5);
    i5 += 4;

    size_t w = input_width;
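    // Main loop: compute 4 output rows x 4 pixels per iteration while more than
    // one full tile of the row remains; the last 1..4 pixels are handled below.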
    for (; w > 4 * sizeof(float); w -= 4 * sizeof(float)) {
      const __m128 vi0x89AB = _mm_loadu_ps(i0);
      i0 += 4;
      const __m128 vi1x89AB = _mm_loadu_ps(i1);
      i1 += 4;
      const __m128 vi2x89AB = _mm_loadu_ps(i2);
      i2 += 4;
      const __m128 vi3x89AB = _mm_loadu_ps(i3);
      i3 += 4;
      const __m128 vi4x89AB = _mm_loadu_ps(i4);
      i4 += 4;
      const __m128 vi5x89AB = _mm_loadu_ps(i5);
      i5 += 4;

      __m128 vo0p0 = _mm_add_ps(vbias, _mm_mul_ps(vi0x4567, vk01));
      __m128 vo1p0 = _mm_add_ps(vbias, _mm_mul_ps(vi1x4567, vk01));
      __m128 vo2p0 = _mm_add_ps(vbias, _mm_mul_ps(vi2x4567, vk01));
      __m128 vo3p0 = _mm_add_ps(vbias, _mm_mul_ps(vi3x4567, vk01));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x4567, vk11));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x4567, vk11));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi3x4567, vk11));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi4x4567, vk11));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x4567, vk21));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x4567, vk21));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi4x4567, vk21));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi5x4567, vk21));

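      // Shift each register pair right by 12 bytes to form {x3,x4,x5,x6}: the
      // left-neighbor column of the current 4-pixel tile.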
      const __m128 vi0x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi0x4567), _mm_castps_si128(vi0x0123), 12));
      const __m128 vi1x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi1x4567), _mm_castps_si128(vi1x0123), 12));
      const __m128 vi2x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi2x4567), _mm_castps_si128(vi2x0123), 12));
      const __m128 vi3x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi3x4567), _mm_castps_si128(vi3x0123), 12));
      const __m128 vi4x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi4x4567), _mm_castps_si128(vi4x0123), 12));
      const __m128 vi5x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi5x4567), _mm_castps_si128(vi5x0123), 12));

      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi0x3456, vk00));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi1x3456, vk00));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi2x3456, vk00));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi3x3456, vk00));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x3456, vk10));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x3456, vk10));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi3x3456, vk10));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi4x3456, vk10));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x3456, vk20));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x3456, vk20));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi4x3456, vk20));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi5x3456, vk20));

      vi0x0123 = vi0x4567;
      vi1x0123 = vi1x4567;
      vi2x0123 = vi2x4567;
      vi3x0123 = vi3x4567;
      vi4x0123 = vi4x4567;
      vi5x0123 = vi5x4567;

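      // Shift each register pair right by 4 bytes to form {x5,x6,x7,x8}: the
      // right-neighbor column of the current 4-pixel tile.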
      const __m128 vi0x5678 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi0x89AB), _mm_castps_si128(vi0x4567), 4));
      const __m128 vi1x5678 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi1x89AB), _mm_castps_si128(vi1x4567), 4));
      const __m128 vi2x5678 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi2x89AB), _mm_castps_si128(vi2x4567), 4));
      const __m128 vi3x5678 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi3x89AB), _mm_castps_si128(vi3x4567), 4));
      const __m128 vi4x5678 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi4x89AB), _mm_castps_si128(vi4x4567), 4));
      const __m128 vi5x5678 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi5x89AB), _mm_castps_si128(vi5x4567), 4));

      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi0x5678, vk02));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi1x5678, vk02));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi2x5678, vk02));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi3x5678, vk02));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x5678, vk12));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x5678, vk12));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi3x5678, vk12));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi4x5678, vk12));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x5678, vk22));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x5678, vk22));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi4x5678, vk22));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi5x5678, vk22));

      vi0x4567 = vi0x89AB;
      vi1x4567 = vi1x89AB;
      vi2x4567 = vi2x89AB;
      vi3x4567 = vi3x89AB;
      vi4x4567 = vi4x89AB;
      vi5x4567 = vi5x89AB;


      __m128 vo0 = _mm_max_ps(vo0p0, vmin);
      __m128 vo1 = _mm_max_ps(vo1p0, vmin);
      __m128 vo2 = _mm_max_ps(vo2p0, vmin);
      __m128 vo3 = _mm_max_ps(vo3p0, vmin);

      vo0 = _mm_min_ps(vo0, vmax);
      vo1 = _mm_min_ps(vo1, vmax);
      vo2 = _mm_min_ps(vo2, vmax);
      vo3 = _mm_min_ps(vo3, vmax);

      _mm_storeu_ps(o3, vo3);
      o3 += 4;
      _mm_storeu_ps(o2, vo2);
      o2 += 4;
      _mm_storeu_ps(o1, vo1);
      o1 += 4;
      _mm_storeu_ps(o0, vo0);
      o0 += 4;
    }
    // Always process the last block of 1..4 pixels.
    assert(w >= 1 * sizeof(float));
    assert(w <= 4 * sizeof(float));
    {
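      // Zero out the lanes past the end of the row so they do not contribute to
      // the final outputs.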
      vi0x4567 = _mm_and_ps(vmask, vi0x4567);
      vi1x4567 = _mm_and_ps(vmask, vi1x4567);
      vi2x4567 = _mm_and_ps(vmask, vi2x4567);
      vi3x4567 = _mm_and_ps(vmask, vi3x4567);
      vi4x4567 = _mm_and_ps(vmask, vi4x4567);
      vi5x4567 = _mm_and_ps(vmask, vi5x4567);

      __m128 vo0p0 = _mm_add_ps(vbias, _mm_mul_ps(vi0x4567, vk01));
      __m128 vo1p0 = _mm_add_ps(vbias, _mm_mul_ps(vi1x4567, vk01));
      __m128 vo2p0 = _mm_add_ps(vbias, _mm_mul_ps(vi2x4567, vk01));
      __m128 vo3p0 = _mm_add_ps(vbias, _mm_mul_ps(vi3x4567, vk01));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x4567, vk11));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x4567, vk11));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi3x4567, vk11));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi4x4567, vk11));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x4567, vk21));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x4567, vk21));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi4x4567, vk21));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi5x4567, vk21));

      const __m128 vi0x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi0x4567), _mm_castps_si128(vi0x0123), 12));
      const __m128 vi1x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi1x4567), _mm_castps_si128(vi1x0123), 12));
      const __m128 vi2x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi2x4567), _mm_castps_si128(vi2x0123), 12));
      const __m128 vi3x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi3x4567), _mm_castps_si128(vi3x0123), 12));
      const __m128 vi4x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi4x4567), _mm_castps_si128(vi4x0123), 12));
      const __m128 vi5x3456 = _mm_castsi128_ps(_mm_alignr_epi8(_mm_castps_si128(vi5x4567), _mm_castps_si128(vi5x0123), 12));

      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi0x3456, vk00));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi1x3456, vk00));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi2x3456, vk00));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi3x3456, vk00));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x3456, vk10));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x3456, vk10));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi3x3456, vk10));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi4x3456, vk10));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x3456, vk20));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x3456, vk20));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi4x3456, vk20));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi5x3456, vk20));

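      // There are no pixels to the right of the row end, so shift in zeroes to
      // form the right-neighbor column (right padding of 1).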
      const __m128i vzero = _mm_setzero_si128();
      const __m128 vi0x5678 = _mm_castsi128_ps(_mm_alignr_epi8(vzero, _mm_castps_si128(vi0x4567), 4));
      const __m128 vi1x5678 = _mm_castsi128_ps(_mm_alignr_epi8(vzero, _mm_castps_si128(vi1x4567), 4));
      const __m128 vi2x5678 = _mm_castsi128_ps(_mm_alignr_epi8(vzero, _mm_castps_si128(vi2x4567), 4));
      const __m128 vi3x5678 = _mm_castsi128_ps(_mm_alignr_epi8(vzero, _mm_castps_si128(vi3x4567), 4));
      const __m128 vi4x5678 = _mm_castsi128_ps(_mm_alignr_epi8(vzero, _mm_castps_si128(vi4x4567), 4));
      const __m128 vi5x5678 = _mm_castsi128_ps(_mm_alignr_epi8(vzero, _mm_castps_si128(vi5x4567), 4));

      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi0x5678, vk02));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi1x5678, vk02));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi2x5678, vk02));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi3x5678, vk02));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi1x5678, vk12));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi2x5678, vk12));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi3x5678, vk12));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi4x5678, vk12));
      vo0p0 = _mm_add_ps(vo0p0, _mm_mul_ps(vi2x5678, vk22));
      vo1p0 = _mm_add_ps(vo1p0, _mm_mul_ps(vi3x5678, vk22));
      vo2p0 = _mm_add_ps(vo2p0, _mm_mul_ps(vi4x5678, vk22));
      vo3p0 = _mm_add_ps(vo3p0, _mm_mul_ps(vi5x5678, vk22));


      __m128 vo0 = _mm_max_ps(vo0p0, vmin);
      __m128 vo1 = _mm_max_ps(vo1p0, vmin);
      __m128 vo2 = _mm_max_ps(vo2p0, vmin);
      __m128 vo3 = _mm_max_ps(vo3p0, vmin);

      vo0 = _mm_min_ps(vo0, vmax);
      vo1 = _mm_min_ps(vo1, vmax);
      vo2 = _mm_min_ps(vo2, vmax);
      vo3 = _mm_min_ps(vo3, vmax);

      if XNN_LIKELY(w == 4 * sizeof(float)) {
        _mm_storeu_ps(o3, vo3);
        o3 += 4;
        _mm_storeu_ps(o2, vo2);
        o2 += 4;
        _mm_storeu_ps(o1, vo1);
        o1 += 4;
        _mm_storeu_ps(o0, vo0);
        o0 += 4;
      } else {
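        // Store the remaining 1..3 pixels: 2 at a time from the low half, then
        // move the high half down for a final single-element store.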
        if (w & (2 * sizeof(float))) {
          _mm_storel_pi((__m64*) o3, vo3);
          o3 += 2;
          _mm_storel_pi((__m64*) o2, vo2);
          o2 += 2;
          _mm_storel_pi((__m64*) o1, vo1);
          o1 += 2;
          _mm_storel_pi((__m64*) o0, vo0);
          o0 += 2;

          vo0 = _mm_movehl_ps(vo0, vo0);
          vo1 = _mm_movehl_ps(vo1, vo1);
          vo2 = _mm_movehl_ps(vo2, vo2);
          vo3 = _mm_movehl_ps(vo3, vo3);
        }
        if (w & (1 * sizeof(float))) {
          _mm_store_ss(o3, vo3);
          o3 += 1;
          _mm_store_ss(o2, vo2);
          o2 += 1;
          _mm_store_ss(o1, vo1);
          o1 += 1;
          _mm_store_ss(o0, vo0);
          o0 += 1;
        }
      }
    }

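    // Advance to the next band of 4 output rows: the rows just traversed by i4 and
    // i5 become i0 and i1 (input_decrement rewinds each pointer to its row start),
    // and the remaining input and output pointers follow consecutively.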
    i0 = (const float*) ((uintptr_t) i4 - input_decrement);
    i1 = (const float*) ((uintptr_t) i5 - input_decrement);
    i2 = (const float*) ((uintptr_t) i1 + input_width);
    i3 = (const float*) ((uintptr_t) i2 + input_width);
    i4 = (const float*) ((uintptr_t) i3 + input_width);
    i5 = (const float*) ((uintptr_t) i4 + input_width);

    o0 = o3;
    o1 = (float*) ((uintptr_t) o0 + input_width);
    o2 = (float*) ((uintptr_t) o1 + input_width);
    o3 = (float*) ((uintptr_t) o2 + input_width);

    output_height = doz(output_height, 4);
  } while (output_height != 0);
}