1// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert ROW_TILE >= 1
7$assert ACCUMULATORS >= 1
8#include <assert.h>
9
10#include <xmmintrin.h>
11
12#include <xnnpack/dwconv.h>
13#include <xnnpack/math.h>
14
15
// 3x3 depthwise convolution of one CHW channel plane with stride 1 and one
// pixel of implicit zero padding on every side (padding_top must be 1).
// Each outer iteration produces up to ROW_TILE output rows, 4 output pixels
// per inner step. The 9 products contributing to an output pixel are spread
// over ACCUMULATORS partial sums that are added together before clamping.
// NOTE(review): with ACCUMULATORS > 6 the right-column taps would update
// partial sums (vo*p6 and up) that are never declared, so only 1..6 appear
// to be supported.
//
// input_height - number of input rows; the output has the same number of rows.
// input_width  - row width in BYTES (a multiple of sizeof(float)); input rows
//                are contiguous at this stride, and output rows use it too.
// input        - first input row.
// weights      - 10 floats: the bias followed by k00..k22, where the first
//                digit is the kernel row and the second the kernel column.
// zero         - a row of zeros substituted for the padding rows above the
//                first and below the last input row.
// output       - first output row.
// padding_top  - must be 1.
// params       - sse.mask (ANDed into the final block of each row) and
//                sse.min / sse.max (output clamping bounds).
//
// NOTE(review): every row, including the zero row, is read in whole 4-float
// chunks, so the final chunk may read up to 3 floats past the end of a row;
// those lanes are ANDed with vmask before use. Presumably the caller
// guarantees the over-read is safe and that zero spans
// round_up_po2(input_width, 16) bytes - confirm.
void xnn_f32_dwconv2d_chw_ukernel_3x3p1__sse_${ROW_TILE}x4${"_acc%d" % ACCUMULATORS if ACCUMULATORS > 1 else ""}(
    size_t input_height,
    size_t input_width,
    const float* input,
    const float* weights,
    const float* zero,
    float* output,
    uint32_t padding_top,
    const union xnn_f32_chw_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(input_height != 0);
  assert(input_width != 0);
  assert(input_width % sizeof(float) == 0);
  assert(padding_top == 1);

  // Loop-invariant vectors: the tail mask, the clamping bounds, and the bias
  // and nine kernel taps, each broadcast to all four lanes.
  const __m128 vmask = _mm_load_ps((const float*) params->sse.mask);
  const __m128 vmax = _mm_load_ps(params->sse.max);
  const __m128 vmin = _mm_load_ps(params->sse.min);

  const __m128 vbias = _mm_load1_ps(weights);
  const __m128 vk00 = _mm_load1_ps(weights + 1);
  const __m128 vk01 = _mm_load1_ps(weights + 2);
  const __m128 vk02 = _mm_load1_ps(weights + 3);
  const __m128 vk10 = _mm_load1_ps(weights + 4);
  const __m128 vk11 = _mm_load1_ps(weights + 5);
  const __m128 vk12 = _mm_load1_ps(weights + 6);
  const __m128 vk20 = _mm_load1_ps(weights + 7);
  const __m128 vk21 = _mm_load1_ps(weights + 8);
  const __m128 vk22 = _mm_load1_ps(weights + 9);

  // Bytes each row pointer advances while sweeping one row in whole 4-float
  // chunks; used to rewind the pointers at the end of each outer iteration.
  const size_t input_decrement = round_up_po2(input_width, 4 * sizeof(float));

  // Sliding window of ROW_TILE + 2 input rows: i0 is the row above output
  // row 0 of the tile (initially the zero padding row), i1 is the input row
  // at the same height as output row 0, and each further pointer is one row
  // lower.
  const float* i0 = zero;
  const float* i1 = input;
  $for M in range(2, 2 + ROW_TILE):
    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);

  // One output pointer per row of the tile.
  float* o0 = output;
  $for M in range(1, ROW_TILE):
    float* o${M} = (float*) ((uintptr_t) o${M-1} + input_width);

  // Number of output rows still to be produced.
  size_t output_height = input_height;
  do {
    // Near the bottom edge, window rows past the last input row read the zero
    // row, and output pointers for missing rows alias the row above them.
    $for M in range(2, 2 + ROW_TILE):
      if XNN_UNPREDICTABLE(output_height < ${M}) {
        i${M} = zero;
        $if M <= ROW_TILE:
          o${M-1} = o${M-2};
      }

    // vi*x3012 holds the previous 4-pixel block rotated so that its last pixel
    // sits in lane 0; starting from zero supplies the left padding column.
    $for M in range(2 + ROW_TILE):
      // vi${M}x3012 = ( vi${M}2, vi${M}1, vi${M}0, vi${M}3 )
      __m128 vi${M}x3012 = _mm_setzero_ps();

    // Load the first 4-pixel block of every row in the window.
    $for M in range(2 + ROW_TILE):
      __m128 vi${M}x4567 = _mm_loadu_ps(i${M});
      i${M} += 4;

    // w is the number of bytes of the row not yet written. The main loop
    // handles every full block except the last one, which the tail below
    // always processes because it may be partial.
    size_t w = input_width;
    for (; w > 4 * sizeof(float); w -= 4 * sizeof(float)) {
      // Load the next block of each row; its first pixel is the right
      // neighbor of the last pixel of the current block.
      $for M in range(2 + ROW_TILE):
        // vi${M}x89AB = ( vi${M}B, vi${M}A, vi${M}9, vi${M}8 )
        const __m128 vi${M}x89AB = _mm_loadu_ps(i${M});
        i${M} += 4;

      $for M in range(2 + ROW_TILE):
        // vi${M}x7456 = ( vi${M}6, vi${M}5, vi${M}4, vi${M}7 )
        const __m128 vi${M}x7456 = _mm_shuffle_ps(vi${M}x4567, vi${M}x4567, _MM_SHUFFLE(2, 1, 0, 3));

      // Kernel column 1 (center taps): kernel row K applied to window row M + K.
      $for K in range(3):
        $for M in range(ROW_TILE):
          $if K == 0:
            __m128 vo${M}p0 = _mm_add_ps(vbias, _mm_mul_ps(vi${M+K}x4567, vk${K}1));
          $elif K < ACCUMULATORS:
            __m128 vo${M}p${K} = _mm_mul_ps(vi${M+K}x4567, vk${K}1);
          $else:
            vo${M}p${K % ACCUMULATORS} = _mm_add_ps(vo${M}p${K % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x4567, vk${K}1));

      // Left neighbors: replace lane 0 of the rotated block with the last
      // pixel of the previous block.
      $for M in range(2 + ROW_TILE):
        // vi${M}x3456 = ( vi${M}6, vi${M}5, vi${M}4, vi${M}3 )
        const __m128 vi${M}x3456 = _mm_move_ss(vi${M}x7456, vi${M}x3012);

      // Kernel column 0 (left taps).
      $for K in range(3):
        $for M in range(ROW_TILE):
          $if K+3 < ACCUMULATORS:
            __m128 vo${M}p${K+3} = _mm_mul_ps(vi${M+K}x3456, vk${K}0);
          $else:
            vo${M}p${(K+3) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+3) % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x3456, vk${K}0));

      // The rotated current block becomes the previous block for the next step.
      $for M in range(2 + ROW_TILE):
        vi${M}x3012 = vi${M}x7456;

      // Right neighbors: put pixel 8 (first pixel of the next block) into
      // lane 0 of the current block, then rotate.
      $for M in range(2 + ROW_TILE):
        // vi${M}x8567 = ( vi${M}7, vi${M}6, vi${M}5, vi${M}8 )
        const __m128 vi${M}x8567 = _mm_move_ss(vi${M}x4567, vi${M}x89AB);

      $for M in range(2 + ROW_TILE):
        // vi${M}x5678 = ( vi${M}8, vi${M}7, vi${M}6, vi${M}5 )
        const __m128 vi${M}x5678 = _mm_shuffle_ps(vi${M}x8567, vi${M}x8567, _MM_SHUFFLE(0, 3, 2, 1));

      // Kernel column 2 (right taps).
      $for K in range(3):
        $for M in range(ROW_TILE):
          vo${M}p${(K+6) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+6) % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x5678, vk${K}2));

      // The next block becomes the current block.
      $for M in range(2 + ROW_TILE):
        vi${M}x4567 = vi${M}x89AB;

      $if ACCUMULATORS > 1:
        // Add the partial sums pairwise into vo*p0.
        $ACC_SLICE = 1
        $while ACC_SLICE < ACCUMULATORS:
          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
            $if A + ACC_SLICE < ACCUMULATORS:
              $for M in range(ROW_TILE):
                vo${M}p${A} = _mm_add_ps(vo${M}p${A}, vo${M}p${A + ACC_SLICE});
          $ACC_SLICE *= 2

      // Clamp to [min, max].
      $for M in range(ROW_TILE):
        __m128 vo${M} = _mm_max_ps(vo${M}p0, vmin);

      $for M in range(ROW_TILE):
        vo${M} = _mm_min_ps(vo${M}, vmax);

      // Store the last row first: when rows alias near the bottom edge, the
      // lowest-numbered of the aliased rows is the one that really exists,
      // and it is written last.
      $for M in reversed(range(ROW_TILE)):
        _mm_storeu_ps(o${M}, vo${M});
        o${M} += 4;
    }
    // Always process the last block of 1..4 pixels.
    assert(w >= 1 * sizeof(float));
    assert(w <= 4 * sizeof(float));
    {
      // Clear the lanes past the end of the row (presumably vmask is all-ones
      // only for the valid lanes) so they act as right padding.
      $for M in range(2 + ROW_TILE):
        vi${M}x4567 = _mm_and_ps(vmask, vi${M}x4567);

      $for M in range(2 + ROW_TILE):
        // vi${M}x7456 = ( vi${M}6, vi${M}5, vi${M}4, vi${M}7 )
        const __m128 vi${M}x7456 = _mm_shuffle_ps(vi${M}x4567, vi${M}x4567, _MM_SHUFFLE(2, 1, 0, 3));

      // Kernel column 1 (center taps).
      $for K in range(3):
        $for M in range(ROW_TILE):
          $if K == 0:
            __m128 vo${M}p0 = _mm_add_ps(vbias, _mm_mul_ps(vi${M+K}x4567, vk${K}1));
          $elif K < ACCUMULATORS:
            __m128 vo${M}p${K} = _mm_mul_ps(vi${M+K}x4567, vk${K}1);
          $else:
            vo${M}p${K % ACCUMULATORS} = _mm_add_ps(vo${M}p${K % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x4567, vk${K}1));

      $for M in range(2 + ROW_TILE):
        // vi${M}x3456 = ( vi${M}6, vi${M}5, vi${M}4, vi${M}3 )
        const __m128 vi${M}x3456 = _mm_move_ss(vi${M}x7456, vi${M}x3012);

      // Kernel column 0 (left taps).
      $for K in range(3):
        $for M in range(ROW_TILE):
          $if K+3 < ACCUMULATORS:
            __m128 vo${M}p${K+3} = _mm_mul_ps(vi${M+K}x3456, vk${K}0);
          $else:
            vo${M}p${(K+3) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+3) % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x3456, vk${K}0));

      // The right neighbor of lane 3 lies past the end of the row, so it is zero.
      const __m128 vzero = _mm_setzero_ps();
      $for M in range(2 + ROW_TILE):
        // vi${M}x8567 = ( vi${M}7, vi${M}6, vi${M}5, 0.0 )
        const __m128 vi${M}x8567 = _mm_move_ss(vi${M}x4567, vzero);

      $for M in range(2 + ROW_TILE):
        // vi${M}x5678 = ( vi${M}8, vi${M}7, vi${M}6, vi${M}5 )
        const __m128 vi${M}x5678 = _mm_shuffle_ps(vi${M}x8567, vi${M}x8567, _MM_SHUFFLE(0, 3, 2, 1));

      // Kernel column 2 (right taps).
      $for K in range(3):
        $for M in range(ROW_TILE):
          vo${M}p${(K+6) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+6) % ACCUMULATORS}, _mm_mul_ps(vi${M+K}x5678, vk${K}2));

      $if ACCUMULATORS > 1:
        // Add the partial sums pairwise into vo*p0.
        $ACC_SLICE = 1
        $while ACC_SLICE < ACCUMULATORS:
          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
            $if A + ACC_SLICE < ACCUMULATORS:
              $for M in range(ROW_TILE):
                vo${M}p${A} = _mm_add_ps(vo${M}p${A}, vo${M}p${A + ACC_SLICE});
          $ACC_SLICE *= 2

      // Clamp to [min, max].
      $for M in range(ROW_TILE):
        __m128 vo${M} = _mm_max_ps(vo${M}p0, vmin);

      $for M in range(ROW_TILE):
        vo${M} = _mm_min_ps(vo${M}, vmax);

      // Write only the w bytes that belong to the row: a full vector, or a
      // pair of floats and/or a single float (moving the upper half down after
      // the pair). Rows are stored last-first for the same aliasing reason as
      // in the main loop.
      if XNN_LIKELY(w == 4 * sizeof(float)) {
        $for M in reversed(range(ROW_TILE)):
          _mm_storeu_ps(o${M}, vo${M});
          o${M} += 4;
      } else {
        if (w & (2 * sizeof(float))) {
          $for M in reversed(range(ROW_TILE)):
            _mm_storel_pi((__m64*) o${M}, vo${M});
            o${M} += 2;

          $for M in range(ROW_TILE):
            vo${M} = _mm_movehl_ps(vo${M}, vo${M});
        }
        if (w & (1 * sizeof(float))) {
          $for M in reversed(range(ROW_TILE)):
            _mm_store_ss(o${M}, vo${M});
            o${M} += 1;
        }
      }
    }

    // Slide the window down by ROW_TILE rows. Each row pointer has advanced
    // input_decrement bytes past its row start, so rewinding the pointers at
    // index ROW_TILE and ROW_TILE + 1 gives the new i0 and i1; the remaining
    // pointers follow at input_width strides.
    i0 = (const float*) ((uintptr_t) i${ROW_TILE} - input_decrement);
    i1 = (const float*) ((uintptr_t) i${ROW_TILE+1} - input_decrement);
    $for M in range(2, 2 + ROW_TILE):
      i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);

    $if ROW_TILE > 1:
      // The last output pointer advanced a full row and now points at the row
      // after this tile, which is where the next tile starts.
      o0 = o${ROW_TILE - 1};
      $for M in range(1, ROW_TILE):
        o${M} = (float*) ((uintptr_t) o${M-1} + input_width);

    $if ROW_TILE > 1:
      // Subtract the rows just produced; doz presumably saturates at zero
      // (difference-or-zero), so a partial final tile ends the loop.
      output_height = doz(output_height, ${ROW_TILE});
  } while (${"--" if ROW_TILE == 1 else ""}output_height != 0);
}
236