1// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert ROW_TILE >= 1
7$assert ACCUMULATORS >= 1
8#include <assert.h>
9
10#include <xmmintrin.h>
11
12#include <xnnpack/dwconv.h>
13#include <xnnpack/math.h>
14
15
// SSE micro-kernel template for a 5x5 filter with 2 rows/columns of padding on
// every side. Each pass of the outer loop produces ROW_TILE output rows, and
// each step along a row produces 4 output pixels per row.
//
// The dollar-prefixed lines are generator directives: ROW_TILE is the number
// of output rows computed per pass, ACCUMULATORS is the number of independent
// partial sums kept per output row.
//
// Parameters:
//   input_height - number of input rows; one output row is written per input
//                  row (output_height starts equal to input_height).
//   input_width  - row length in BYTES (asserted to be a multiple of
//                  sizeof(float)); it is also the byte stride used between
//                  consecutive rows of both input and output.
//   input        - pointer to the first input row.
//   weights      - 26 floats: weights[0] is the bias and weights[1 + 5*R + S]
//                  is the tap for kernel row R, kernel column S.
//   zero         - buffer read in place of the rows above and below the image.
//                  NOTE(review): presumably filled with zeros and at least one
//                  (rounded-up) row long - confirm against the caller.
//   output       - pointer to the first output row.
//   padding_top  - must be 2 (asserted).
//   params       - supplies the tail mask and the min/max clamp bounds.
//
// Naming convention used below: viMxABCD is a vector for input row M whose
// four lanes hold columns A, B, C, D relative to the current 4-pixel block
// (4..7 is the block itself, 2..3 lie to its left, 8..B to its right).
16void xnn_f32_dwconv2d_chw_ukernel_5x5p2__sse_${ROW_TILE}x4${"_acc%d" % ACCUMULATORS if ACCUMULATORS > 1 else ""}(
17    size_t input_height,
18    size_t input_width,
19    const float* input,
20    const float* weights,
21    const float* zero,
22    float* output,
23    uint32_t padding_top,
24    const union xnn_f32_chw_params params[restrict XNN_MIN_ELEMENTS(1)])
25{
26  assert(input_height != 0);
27  assert(input_width != 0);
  // input_width is a byte count, so it must cover a whole number of floats.
28  assert(input_width % sizeof(float) == 0);
29  assert(padding_top == 2);
30
  // Tail mask and clamp bounds. _mm_load_ps is an aligned load, so these
  // params fields have to be 16-byte aligned.
  // NOTE(review): vmask presumably has all bits set in the lanes that hold
  // real pixels of the final vector of a row - confirm against the code that
  // initializes params.
31  const __m128 vmask = _mm_load_ps((const float*) params->sse.mask);
32  const __m128 vmax = _mm_load_ps(params->sse.max);
33  const __m128 vmin = _mm_load_ps(params->sse.min);
34
  // Broadcast the bias and each of the 25 taps into all four lanes:
  // vkRS holds the tap for kernel row R, kernel column S.
35  const __m128 vbias = _mm_load1_ps(weights);
36  $for R in range(5):
37    $for S in range(5):
38      const __m128 vk${R}${S} = _mm_load1_ps(weights + ${R*5+S+1});
39
  // Number of bytes every input row pointer advances during one pass along a
  // row: rows are consumed in whole 4-float vectors, so this is input_width
  // rounded up to a multiple of 4 floats. Used to rewind the pointers below.
40  const size_t input_decrement = round_up_po2(input_width, 4 * sizeof(float));
41
  // Input row pointers for the first pass. The two rows above the image
  // (top padding) are read from `zero`; i2 is the first real input row and
  // each following pointer is one row (input_width bytes) further down.
42  const float* i0 = zero;
43  const float* i1 = zero;
44  const float* i2 = input;
45  $for M in range(3, 4 + ROW_TILE):
46    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);
47
  // Output row pointers, spaced by the same byte stride as the input rows.
48  float* o0 = output;
49  $for M in range(1, ROW_TILE):
50    float* o${M} = (float*) ((uintptr_t) o${M-1} + input_width);
51
52  size_t output_height = input_height;
53  do {
    // Bottom edge handling: input rows that fall below the image are read
    // from `zero`, and output row pointers that have no row left to write are
    // aliased to the previous output row pointer. Stores further down are
    // issued in descending row order, so when pointers alias, the write of
    // the lowest-numbered row lands last.
54    $for M in range(2, 3 + ROW_TILE):
55      if XNN_UNPREDICTABLE(output_height < ${M}) {
56        i${M+1} = zero;
57        $if M <= ROW_TILE:
58          o${M-1} = o${M-2};
59      }
60
    // Left padding: lane 3 (column 2) and lane 0 (column 3) of vi*x3012 are
    // what the shuffles below read as the two columns left of the block, so
    // starting from zero provides the two padding columns at the row start.
61    $for M in range(4 + ROW_TILE):
62      __m128 vi${M}x3012 = _mm_setzero_ps();
63
    // Prime the pipeline with the first 4 pixels of every input row.
64    $for M in range(4 + ROW_TILE):
65      __m128 vi${M}x4567 = _mm_loadu_ps(i${M});
66      i${M} += 4;
67
    // w counts the bytes of the row remaining from the start of vi*x4567.
68    size_t w = input_width;
    // Main loop: runs while more than 8 pixels remain, so the next vector
    // (columns 8..B) lies entirely inside the row and needs no masking.
69    for (; w > 8 * sizeof(float); w -= 4 * sizeof(float)) {
      // Center column (kernel column 2) applied to the unshifted block. The
      // first products also create the accumulators; accumulator 0 starts
      // from the bias.
70      $for K in range(5):
71        $for M in range(ROW_TILE):
72          $if K == 0:
73            __m128 vo${M}p0 = _mm_add_ps(vbias, _mm_mul_ps(vi${M+K}x4567, vk${K}2));
74          $elif K < ACCUMULATORS:
75            __m128 vo${M}p${K} = _mm_mul_ps(vi${M+K}x4567, vk${K}2);
76          $else:
77            vo${M}p${K % ACCUMULATORS} = _mm_add_ps(vo${M}p${K % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x4567, vk${K}2));
78
      // Rotate lanes by one so that column 7 sits in lane 0: (7, 4, 5, 6).
79      $for M in range(4 + ROW_TILE):
80        const __m128 vi${M}x7456 = _mm_shuffle_ps(vi${M}x4567, vi${M}x4567, _MM_SHUFFLE(2, 1, 0, 3));
81
      // Load the next 4 pixels of every row (columns 8..B).
82      $for M in range(4 + ROW_TILE):
83        const __m128 vi${M}x89AB = _mm_loadu_ps(i${M});
84        i${M} += 4;
85
      // Replace lane 0 with column 3 carried over from the previous block,
      // giving the block shifted left by one: (3, 4, 5, 6).
86      $for M in range(4 + ROW_TILE):
87        const __m128 vi${M}x3456 = _mm_move_ss(vi${M}x7456, vi${M}x3012);
88
      // Kernel column 1.
89      $for K in range(5):
90        $for M in range(ROW_TILE):
91          vo${M}p${(K+5) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+5) % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x3456, vk${K}1));
92
      // Block shifted left by two: (2, 3) from the carried vector, (4, 5)
      // from the rotated block. The rotated block then becomes the carried
      // vector for the next iteration.
93      $for M in range(4 + ROW_TILE):
94        const __m128 vi${M}x2345 = _mm_shuffle_ps(vi${M}x3012, vi${M}x7456, _MM_SHUFFLE(2, 1, 0, 3));
95        vi${M}x3012 = vi${M}x7456;
96
      // Put column 8 into lane 0 of the current block: (8, 5, 6, 7), then
      // advance the current block to the freshly loaded vector.
97      $for M in range(4 + ROW_TILE):
98        const __m128 vi${M}x8567 = _mm_move_ss(vi${M}x4567, vi${M}x89AB);
99        vi${M}x4567 = vi${M}x89AB;
100
      // Kernel column 0.
101      $for K in range(5):
102        $for M in range(ROW_TILE):
103          vo${M}p${(K+10) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+10) % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x2345, vk${K}0));
104
      // Rotate the other way to get the block shifted right by one: (5, 6, 7, 8).
105      $for M in range(4 + ROW_TILE):
106        const __m128 vi${M}x5678 = _mm_shuffle_ps(vi${M}x8567, vi${M}x8567, _MM_SHUFFLE(0, 3, 2, 1));
107
      // Kernel column 3.
108      $for K in range(5):
109        $for M in range(ROW_TILE):
110          vo${M}p${(K+15) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+15) % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x5678, vk${K}3));
111
      // Block shifted right by two: (6, 7) from the shifted block, (8, 9)
      // from the next vector.
112      $for M in range(4 + ROW_TILE):
113        const __m128 vi${M}x6789 = _mm_shuffle_ps(vi${M}x5678, vi${M}x89AB, _MM_SHUFFLE(1, 0, 2, 1));
114
      // Kernel column 4.
115      $for K in range(5):
116        $for M in range(ROW_TILE):
117          vo${M}p${(K+20) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+20) % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x6789, vk${K}4));
118
      // Fold the partial accumulators pairwise into vo*p0 (nothing is emitted
      // here when ACCUMULATORS is 1).
119      $if ACCUMULATORS > 1:
120        $ACC_SLICE = 1
121        $while ACC_SLICE < ACCUMULATORS:
122          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
123            $if A + ACC_SLICE < ACCUMULATORS:
124              $for M in range(ROW_TILE):
125                vo${M}p${A} = _mm_add_ps(vo${M}p${A}, vo${M}p${A + ACC_SLICE});
126          $ACC_SLICE *= 2
127
      // Clamp the results to [vmin, vmax].
128      $for M in range(ROW_TILE):
129        __m128 vo${M} = _mm_max_ps(vo${M}p0, vmin);
130
131      $for M in range(ROW_TILE):
132        vo${M} = _mm_min_ps(vo${M}, vmax);
133
      // Store 4 pixels per row, highest row first (see the aliasing note at
      // the top of the outer loop).
134      $for M in reversed(range(ROW_TILE)):
135        _mm_storeu_ps(o${M}, vo${M});
136        o${M} += 4;
137    }
138    // Always process the last block of 5..8 pixels.
    // Same computation as the main loop, except that the vector loaded here
    // is the final one of the row and is therefore ANDed with vmask.
    // NOTE(review): the load still reads a full 4-float vector, so it can
    // touch up to 3 floats past the end of the row - callers presumably have
    // to tolerate that over-read; confirm.
139    if XNN_LIKELY(w > 4 * sizeof(float)) {
140      $for K in range(5):
141        $for M in range(ROW_TILE):
142          $if K == 0:
143            __m128 vo${M}p0 = _mm_add_ps(vbias, _mm_mul_ps(vi${M+K}x4567, vk${K}2));
144          $elif K < ACCUMULATORS:
145            __m128 vo${M}p${K} = _mm_mul_ps(vi${M+K}x4567, vk${K}2);
146          $else:
147            vo${M}p${K % ACCUMULATORS} = _mm_add_ps(vo${M}p${K % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x4567, vk${K}2));
148
149      $for M in range(4 + ROW_TILE):
150        const __m128 vi${M}x7456 = _mm_shuffle_ps(vi${M}x4567, vi${M}x4567, _MM_SHUFFLE(2, 1, 0, 3));
151
      // Final vector of the row: masked so that lanes cleared by vmask
      // contribute zero to the sums.
152      $for M in range(4 + ROW_TILE):
153        const __m128 vi${M}x89AB = _mm_and_ps(_mm_loadu_ps(i${M}), vmask);
154        i${M} += 4;
155
156      $for M in range(4 + ROW_TILE):
157        const __m128 vi${M}x3456 = _mm_move_ss(vi${M}x7456, vi${M}x3012);
158
159      $for K in range(5):
160        $for M in range(ROW_TILE):
161          vo${M}p${(K+5) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+5) % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x3456, vk${K}1));
162
163      $for M in range(4 + ROW_TILE):
164        const __m128 vi${M}x2345 = _mm_shuffle_ps(vi${M}x3012, vi${M}x7456, _MM_SHUFFLE(2, 1, 0, 3));
165        vi${M}x3012 = vi${M}x7456;
166
167      $for M in range(4 + ROW_TILE):
168        const __m128 vi${M}x8567 = _mm_move_ss(vi${M}x4567, vi${M}x89AB);
169        vi${M}x4567 = vi${M}x89AB;
170
171      $for K in range(5):
172        $for M in range(ROW_TILE):
173          vo${M}p${(K+10) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+10) % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x2345, vk${K}0));
174
175      $for M in range(4 + ROW_TILE):
176        const __m128 vi${M}x5678 = _mm_shuffle_ps(vi${M}x8567, vi${M}x8567, _MM_SHUFFLE(0, 3, 2, 1));
177
178      $for K in range(5):
179        $for M in range(ROW_TILE):
180          vo${M}p${(K+15) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+15) % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x5678, vk${K}3));
181
182      $for M in range(4 + ROW_TILE):
183        const __m128 vi${M}x6789 = _mm_shuffle_ps(vi${M}x5678, vi${M}x89AB, _MM_SHUFFLE(1, 0, 2, 1));
184
185      $for K in range(5):
186        $for M in range(ROW_TILE):
187          vo${M}p${(K+20) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+20) % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x6789, vk${K}4));
188
189      $if ACCUMULATORS > 1:
190        $ACC_SLICE = 1
191        $while ACC_SLICE < ACCUMULATORS:
192          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
193            $if A + ACC_SLICE < ACCUMULATORS:
194              $for M in range(ROW_TILE):
195                vo${M}p${A} = _mm_add_ps(vo${M}p${A}, vo${M}p${A + ACC_SLICE});
196          $ACC_SLICE *= 2
197
198      $for M in range(ROW_TILE):
199        __m128 vo${M} = _mm_max_ps(vo${M}p0, vmin);
200
201      $for M in range(ROW_TILE):
202        vo${M} = _mm_min_ps(vo${M}, vmax);
203
204      $for M in reversed(range(ROW_TILE)):
205        _mm_storeu_ps(o${M}, vo${M});
206        o${M} += 4;
207
208      w -= 4 * sizeof(float);
209    }
    // Final block of 1..4 pixels. Nothing more is loaded: everything to the
    // right of the current block is padding and is supplied as vzero.
210    assert(w >= 1 * sizeof(float));
211    assert(w <= 4 * sizeof(float));
212    {
      // Mask the current block. This matters when the row is at most 4
      // pixels wide and the block came straight from the unmasked priming
      // load; after the 5..8 block it was already masked, so the AND is a
      // no-op there.
213      $for M in range(4 + ROW_TILE):
214        vi${M}x4567 = _mm_and_ps(vi${M}x4567, vmask);
215
      // Kernel column 2 (center).
216      $for K in range(5):
217        $for M in range(ROW_TILE):
218          $if K == 0:
219            __m128 vo${M}p0 = _mm_add_ps(vbias, _mm_mul_ps(vi${M+K}x4567, vk${K}2));
220          $elif K < ACCUMULATORS:
221            __m128 vo${M}p${K} = _mm_mul_ps(vi${M+K}x4567, vk${K}2);
222          $else:
223            vo${M}p${K % ACCUMULATORS} = _mm_add_ps(vo${M}p${K % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x4567, vk${K}2));
224
225      $for M in range(4 + ROW_TILE):
226        const __m128 vi${M}x7456 = _mm_shuffle_ps(vi${M}x4567, vi${M}x4567, _MM_SHUFFLE(2, 1, 0, 3));
227
228      $for M in range(4 + ROW_TILE):
229        const __m128 vi${M}x3456 = _mm_move_ss(vi${M}x7456, vi${M}x3012);
230
      // Kernel column 1.
231      $for K in range(5):
232        $for M in range(ROW_TILE):
233          vo${M}p${(K+5) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+5) % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x3456, vk${K}1));
234
235      $for M in range(4 + ROW_TILE):
236        const __m128 vi${M}x2345 = _mm_shuffle_ps(vi${M}x3012, vi${M}x7456, _MM_SHUFFLE(2, 1, 0, 3));
237
      // Right padding: column 8 is taken from vzero instead of a loaded vector.
238      const __m128 vzero = _mm_setzero_ps();
239      $for M in range(4 + ROW_TILE):
240        const __m128 vi${M}x8567 = _mm_move_ss(vi${M}x4567, vzero);
241
      // Kernel column 0.
242      $for K in range(5):
243        $for M in range(ROW_TILE):
244          vo${M}p${(K+10) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+10) % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x2345, vk${K}0));
245
246      $for M in range(4 + ROW_TILE):
247        const __m128 vi${M}x5678 = _mm_shuffle_ps(vi${M}x8567, vi${M}x8567, _MM_SHUFFLE(0, 3, 2, 1));
248
      // Kernel column 3.
249      $for K in range(5):
250        $for M in range(ROW_TILE):
251          vo${M}p${(K+15) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+15) % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x5678, vk${K}3));
252
      // Columns 8 and 9 (upper two lanes) both come from vzero.
253      $for M in range(4 + ROW_TILE):
254        const __m128 vi${M}x6789 = _mm_shuffle_ps(vi${M}x5678, vzero, _MM_SHUFFLE(1, 0, 2, 1));
255
      // Kernel column 4.
256      $for K in range(5):
257        $for M in range(ROW_TILE):
258          vo${M}p${(K+20) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+20) % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x6789, vk${K}4));
259
      // Fold the partial accumulators pairwise into vo*p0.
260      $if ACCUMULATORS > 1:
261        $ACC_SLICE = 1
262        $while ACC_SLICE < ACCUMULATORS:
263          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
264            $if A + ACC_SLICE < ACCUMULATORS:
265              $for M in range(ROW_TILE):
266                vo${M}p${A} = _mm_add_ps(vo${M}p${A}, vo${M}p${A + ACC_SLICE});
267          $ACC_SLICE *= 2
268
      // Clamp the results to [vmin, vmax].
269      $for M in range(ROW_TILE):
270        __m128 vo${M} = _mm_max_ps(vo${M}p0, vmin);
271
272      $for M in range(ROW_TILE):
273        vo${M} = _mm_min_ps(vo${M}, vmax);
274
      // Write exactly the remaining pixels. Since w is between 1 and 4
      // floats, the 4-float bit being set means a full vector; otherwise the
      // 2-float and 1-float bits select a 2-lane store, a 1-lane store, or
      // both (3 pixels).
275      if XNN_LIKELY(w & (4 * sizeof(float))) {
276        $for M in reversed(range(ROW_TILE)):
277          _mm_storeu_ps(o${M}, vo${M});
278          o${M} += 4;
279      } else {
280        if (w & (2 * sizeof(float))) {
          // Store the low two lanes, then move the high two lanes down so a
          // possible third pixel ends up in lane 0.
281          $for M in reversed(range(ROW_TILE)):
282            _mm_storel_pi((__m64*) o${M}, vo${M});
283            o${M} += 2;
284
285          $for M in range(ROW_TILE):
286            vo${M} = _mm_movehl_ps(vo${M}, vo${M});
287        }
288        if (w & (1 * sizeof(float))) {
289          $for M in reversed(range(ROW_TILE)):
290            _mm_store_ss(o${M}, vo${M});
291            o${M} += 1;
292        }
293      }
294    }
295
    // Slide the input window down by ROW_TILE rows. The pointers used for the
    // last pass have advanced by input_decrement bytes, so rewinding two of
    // them yields the start of the rows that become the new i0 and i1; the
    // remaining pointers are rebuilt one row stride apart.
296    i0 = (const float*) ((uintptr_t) i${ROW_TILE} - input_decrement);
297    i1 = (const float*) ((uintptr_t) i${ROW_TILE+1} - input_decrement);
298    $for M in range(2, 4 + ROW_TILE):
299      i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);
300
    // The output pointers advanced by exactly one row during the pass, so the
    // last row pointer now addresses the first row of the next tile. With a
    // single-row tile o0 is already in the right place and nothing is emitted.
301    $if ROW_TILE > 1:
302      o0 = o${ROW_TILE - 1};
303      $for M in range(1, ROW_TILE):
304        o${M} = (float*) ((uintptr_t) o${M-1} + input_width);
305
    // NOTE(review): doz is defined outside this file - presumably a
    // "difference or zero" (saturating) subtraction, so the count cannot wrap
    // when fewer than ROW_TILE rows remained; confirm. A single-row tile
    // instead decrements directly in the loop condition.
306    $if ROW_TILE > 1:
307      output_height = doz(output_height, ${ROW_TILE});
308  } while (${"--" if ROW_TILE == 1 else ""}output_height != 0);
309}
310