// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert ROW_TILE >= 1
$assert ACCUMULATORS >= 1
#include <assert.h>

#include <xmmintrin.h>

#include <xnnpack/dwconv.h>
#include <xnnpack/math.h>


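// Depthwise 3x3 convolution with stride 2 ("s2") on CHW-layout data. The left
// edge gets one implicit column of zero padding (see the zeroed vi*x7531
// vectors below), the top edge gets `padding_top` (0 or 1) rows, and the
// bottom edge one row. Each pass of the main loop computes ${ROW_TILE} output
// row(s), 4 output pixels at a time, split across ${ACCUMULATORS} partial
// accumulator(s) per row.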
void xnn_f32_dwconv2d_chw_ukernel_3x3s2p1__sse_${ROW_TILE}x4${"_acc%d" % ACCUMULATORS if ACCUMULATORS > 1 else ""}(
    size_t input_height,
    size_t input_width,
    const float* input,
    const float* weights,
    const float* zero,
    float* output,
    uint32_t padding_top,
    const union xnn_f32_chw_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(input_height != 0);
  assert(input_width != 0);
  assert(input_width % sizeof(float) == 0);
  assert(padding_top >= 0);
  assert(padding_top <= 1);

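  // mask_even/mask_odd zero out-of-bounds columns in the remainder block at
  // the end of each row; min/max are the output clamping bounds.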
  const __m128 vmask_even = _mm_load_ps((const float*) params->sse.mask_even);
  const __m128 vmask_odd  = _mm_load_ps((const float*) params->sse.mask_odd);
  const __m128 vmax = _mm_load_ps(params->sse.max);
  const __m128 vmin = _mm_load_ps(params->sse.min);

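  // Weights are packed as the bias followed by the 3x3 kernel in row-major
  // order, each value broadcast to all 4 lanes.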
  const __m128 vbias = _mm_load1_ps(weights);
  const __m128 vk00 = _mm_load1_ps(weights + 1);
  const __m128 vk01 = _mm_load1_ps(weights + 2);
  const __m128 vk02 = _mm_load1_ps(weights + 3);
  const __m128 vk10 = _mm_load1_ps(weights + 4);
  const __m128 vk11 = _mm_load1_ps(weights + 5);
  const __m128 vk12 = _mm_load1_ps(weights + 6);
  const __m128 vk20 = _mm_load1_ps(weights + 7);
  const __m128 vk21 = _mm_load1_ps(weights + 8);
  const __m128 vk22 = _mm_load1_ps(weights + 9);

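  // Bytes by which the main loop below advances the row pointers (it consumes
  // 8 input pixels per iteration); used to rewind them for the next row tile.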
  const size_t input_decrement = round_down_po2(input_width, 4 /* SIMD output width */ * 2 /* subsampling */ * sizeof(float));
  $if ROW_TILE > 1:
    const size_t output_width = round_down_po2((input_width + (2 /* padding */ - 3 /* kernel size */ + 2 /* subsampling */) * sizeof(float)) / 2, sizeof(float));

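  // i0 points one row above the input; when padding_top != 0 that row is out
  // of bounds and is replaced with the zero row.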
  const float* i0 = (const float*) ((uintptr_t) input - ((-padding_top) & input_width));
  const float* i1 = (const float*) ((uintptr_t) i0 + input_width);
  if XNN_UNPREDICTABLE(padding_top != 0) {
    i0 = zero;
  }
  $for M in range(2, 1 + 2 * ROW_TILE):
    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);

  float* o0 = output;
  $for M in range(1, ROW_TILE):
    float* o${M} = (float*) ((uintptr_t) o${M-1} + output_width);

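  // With a 3-row kernel and a vertical stride of 2, a padded height of H
  // yields (H - 1) / 2 output rows.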
  size_t padded_input_height = input_height + padding_top + 1 /* padding bottom */;
  size_t output_height = (padded_input_height - 3 /* kernel size */ + 2 /* subsampling */) / 2;
  do {
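    // Substitute the zero row for input rows past the end of the (padded)
    // input; redirect the output pointer of a row that does not exist to the
    // previous row, whose stores (issued later, in reverse order) overwrite
    // the dummy values.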
    $for M in range(2, 1 + 2 * ROW_TILE):
      if XNN_UNPREDICTABLE(padded_input_height < ${2 + M}) {
        i${M} = zero;
        $if M % 2 == 1:
          o${(M - 1) // 2} = o${(M - 1) // 2 - 1};
      }

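    // Lane 0 of each vi*x7531 vector carries the last odd-index pixel of the
    // previous block into the next iteration; its initial zero doubles as the
    // left padding column.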
    $for M in range(1 + 2 * ROW_TILE):
      __m128 vi${M}x7531 = _mm_setzero_ps();

    size_t w = input_width;
    for (; w >= 8 * sizeof(float); w -= 8 * sizeof(float)) {
      $for M in range(1 + 2 * ROW_TILE):
        const __m128 vi${M}x89AB = _mm_loadu_ps(i${M});
        const __m128 vi${M}xCDEF = _mm_loadu_ps(i${M} + 4);
        i${M} += 8;

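      // De-interleave 8 consecutive pixels into the even columns (8, A, C, E)
      // and the odd columns (9, B, D, F).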
      $for M in range(1 + 2 * ROW_TILE):
        const __m128 vi${M}x8ACE = _mm_shuffle_ps(vi${M}x89AB, vi${M}xCDEF, _MM_SHUFFLE(2, 0, 2, 0));
        const __m128 vi${M}x9BDF = _mm_shuffle_ps(vi${M}x89AB, vi${M}xCDEF, _MM_SHUFFLE(3, 1, 3, 1));

      $for K in range(3):
        $for M in range(ROW_TILE):
          $if K == 0:
            __m128 vo${M}p0 = _mm_add_ps(vbias, _mm_mul_ps(vi${2*M+K}x8ACE, vk${K}1));
          $elif K < ACCUMULATORS:
            __m128 vo${M}p${K} = _mm_mul_ps(vi${2*M+K}x8ACE, vk${K}1);
          $else:
            vo${M}p${K % ACCUMULATORS} = _mm_add_ps(vo${M}p${K % ACCUMULATORS}, _mm_mul_ps(vi${2*M+K}x8ACE, vk${K}1));

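      // Rotate the odd columns right by one lane, moving the stale pixel F
      // into lane 0; it is spliced out below.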
      $for M in range(1 + 2 * ROW_TILE):
        const __m128 vi${M}xF9BD = _mm_shuffle_ps(vi${M}x9BDF, vi${M}x9BDF, _MM_SHUFFLE(2, 1, 0, 3));

      $for K in range(3):
        $for M in range(ROW_TILE):
          $if K+3 < ACCUMULATORS:
            __m128 vo${M}p${K+3} = _mm_mul_ps(vi${2*M+K}x9BDF, vk${K}2);
          $else:
            vo${M}p${(K+3) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+3) % ACCUMULATORS}, _mm_mul_ps(vi${2*M+K}x9BDF, vk${K}2));

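      // Splice the previous block's last odd pixel (the "7" column) into
      // lane 0 to form the left neighbors of the even columns.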
      $for M in range(1 + 2 * ROW_TILE):
        const __m128 vi${M}x7BDF = _mm_move_ss(vi${M}xF9BD, vi${M}x7531);

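      // Save the rotated odd columns: their pixel F is pixel 7 of the next
      // block.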
      $for M in range(1 + 2 * ROW_TILE):
        vi${M}x7531 = vi${M}xF9BD;

      $for K in range(3):
        $for M in range(ROW_TILE):
          vo${M}p${(K+6) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+6) % ACCUMULATORS}, _mm_mul_ps(vi${2*M+K}x7BDF, vk${K}0));

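      $if ACCUMULATORS > 1:
        // Fold the partial sums into vo*p0 with a binary reduction tree.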
      $if ACCUMULATORS > 1:
        $ACC_SLICE = 1
        $while ACC_SLICE < ACCUMULATORS:
          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
            $if A + ACC_SLICE < ACCUMULATORS:
              $for M in range(ROW_TILE):
                vo${M}p${A} = _mm_add_ps(vo${M}p${A}, vo${M}p${A + ACC_SLICE});
          $ACC_SLICE *= 2

      $for M in range(ROW_TILE):
        __m128 vo${M} = _mm_max_ps(vo${M}p0, vmin);

      $for M in range(ROW_TILE):
        vo${M} = _mm_min_ps(vo${M}, vmax);

      $for M in reversed(range(ROW_TILE)):
        _mm_storeu_ps(o${M}, vo${M});
        o${M} += 4;
    }
    // Potentially process the last block of 0..7 pixels.
    assert(w < 8 * sizeof(float));
    if XNN_LIKELY(w != 0) {
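      // The loads below read a full block of 8 pixels; the even/odd masks
      // zero the columns that fall past the end of the row.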
      $for M in range(1 + 2 * ROW_TILE):
        const __m128 vi${M}x89AB = _mm_loadu_ps(i${M});
        const __m128 vi${M}xCDEF = _mm_loadu_ps(i${M} + 4);

      $for M in range(1 + 2 * ROW_TILE):
        const __m128 vi${M}x8ACE = _mm_and_ps(vmask_even, _mm_shuffle_ps(vi${M}x89AB, vi${M}xCDEF, _MM_SHUFFLE(2, 0, 2, 0)));
        const __m128 vi${M}x9BDF = _mm_and_ps(vmask_odd,  _mm_shuffle_ps(vi${M}x89AB, vi${M}xCDEF, _MM_SHUFFLE(3, 1, 3, 1)));

      $for K in range(3):
        $for M in range(ROW_TILE):
          $if K == 0:
            __m128 vo${M}p0 = _mm_add_ps(vbias, _mm_mul_ps(vi${2*M+K}x8ACE, vk${K}1));
          $elif K < ACCUMULATORS:
            __m128 vo${M}p${K} = _mm_mul_ps(vi${2*M+K}x8ACE, vk${K}1);
          $else:
            vo${M}p${K % ACCUMULATORS} = _mm_add_ps(vo${M}p${K % ACCUMULATORS}, _mm_mul_ps(vi${2*M+K}x8ACE, vk${K}1));

      $for M in range(1 + 2 * ROW_TILE):
        const __m128 vi${M}xF9BD = _mm_shuffle_ps(vi${M}x9BDF, vi${M}x9BDF, _MM_SHUFFLE(2, 1, 0, 3));

      $for K in range(3):
        $for M in range(ROW_TILE):
          $if K+3 < ACCUMULATORS:
            __m128 vo${M}p${K+3} = _mm_mul_ps(vi${2*M+K}x9BDF, vk${K}2);
          $else:
            vo${M}p${(K+3) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+3) % ACCUMULATORS}, _mm_mul_ps(vi${2*M+K}x9BDF, vk${K}2));

      $for M in range(1 + 2 * ROW_TILE):
        const __m128 vi${M}x7BDF = _mm_move_ss(vi${M}xF9BD, vi${M}x7531);

      $for M in range(1 + 2 * ROW_TILE):
        vi${M}x7531 = vi${M}xF9BD;

      $for K in range(3):
        $for M in range(ROW_TILE):
          vo${M}p${(K+6) % ACCUMULATORS} = _mm_add_ps(vo${M}p${(K+6) % ACCUMULATORS}, _mm_mul_ps(vi${2*M+K}x7BDF, vk${K}0));

      $if ACCUMULATORS > 1:
        $ACC_SLICE = 1
        $while ACC_SLICE < ACCUMULATORS:
          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
            $if A + ACC_SLICE < ACCUMULATORS:
              $for M in range(ROW_TILE):
                vo${M}p${A} = _mm_add_ps(vo${M}p${A}, vo${M}p${A + ACC_SLICE});
          $ACC_SLICE *= 2

      $for M in range(ROW_TILE):
        __m128 vo${M} = _mm_max_ps(vo${M}p0, vmin);

      $for M in range(ROW_TILE):
        vo${M} = _mm_min_ps(vo${M}, vmax);

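      // 1 to 7 input pixels remain, producing (w + 1) / 2 output pixels
      // (in units of sizeof(float)). Only w == 7 pixels stores a full 4
      // outputs; otherwise the 4-bit of w + 1 selects a 2-output store and
      // the 2-bit a final 1-output store.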
      if (w == 7 * sizeof(float)) {
        $for M in reversed(range(ROW_TILE)):
          _mm_storeu_ps(o${M}, vo${M});
          o${M} += 4;
      } else {
        w += 1 * sizeof(float);
        if (w & (4 * sizeof(float))) {
          $for M in reversed(range(ROW_TILE)):
            _mm_storel_pi((__m64*) o${M}, vo${M});
            o${M} += 2;

          $for M in range(ROW_TILE):
            vo${M} = _mm_movehl_ps(vo${M}, vo${M});
        }
        if (w & (2 * sizeof(float))) {
          $for M in reversed(range(ROW_TILE)):
            _mm_store_ss(o${M}, vo${M});
            o${M} += 1;
        }
      }
    }

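    // Advance to the next row tile: with the vertical stride of 2 the new i0
    // is the row previously addressed by the last input pointer, rewound by
    // the bytes consumed in the main loop.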
    i0 = (const float*) ((uintptr_t) i${2 * ROW_TILE} - input_decrement);
    $for M in range(1, 1 + 2 * ROW_TILE):
      i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);

    $if ROW_TILE > 1:
      o0 = o${ROW_TILE - 1};
      $for M in range(1, ROW_TILE):
        o${M} = (float*) ((uintptr_t) o${M-1} + output_width);

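    // When ROW_TILE > 1, doz() (decrement-or-zero) saturates at zero, so the
    // counters cannot underflow on a partial final tile.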
    $if ROW_TILE > 1:
      output_height = doz(output_height, ${ROW_TILE});
      padded_input_height = doz(padded_input_height, ${ROW_TILE * 2});
    $else:
      output_height -= 1;
      padded_input_height -= 2;
  } while (output_height != 0);
}