// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

6$assert ROW_TILE >= 1
7$assert ACCUMULATORS >= 1
8#include <assert.h>
9
10#include <xnnpack/dwconv.h>
11#include <xnnpack/math.h>
12
13
14void xnn_f32_dwconv2d_chw_ukernel_5x5p2__scalar_${ROW_TILE}x1${"_acc%d" % ACCUMULATORS if ACCUMULATORS > 1 else ""}(
15    size_t input_height,
16    size_t input_width,
17    const float* input,
18    const float* weights,
19    const float* zero,
20    float* output,
21    uint32_t padding_top,
22    const union xnn_f32_chw_params params[restrict XNN_MIN_ELEMENTS(1)])
23{
24  assert(input_height != 0);
25  assert(input_width != 0);
26  assert(input_width % sizeof(float) == 0);
27  assert(padding_top == 2);
28
29  const float vmin = params->scalar.min;
30  const float vmax = params->scalar.max;
31
32  const float vbias = weights[0];
33  $for R in range(5):
34    $for S in range(5):
35      const float vk${R}${S} = weights[${R*5+S+1}];
36
37  const float* i0 = zero;
38  const float* i1 = zero;
39  const float* i2 = input;
40  $for M in range(3, 4 + ROW_TILE):
41    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);
42
43  float* o0 = output;
44  $for M in range(1, ROW_TILE):
45    float* o${M} = (float*) ((uintptr_t) o${M-1} + input_width);
46
47  size_t output_height = input_height;
48  do {
49    $for M in range(2, 3 + ROW_TILE):
50      if XNN_UNPREDICTABLE(output_height < ${M}) {
51        i${M+1} = zero;
52        $if M <= ROW_TILE:
53          o${M-1} = o${M-2};
54      }
55
56    $for M in range(4 + ROW_TILE):
57      float vi${M}x0 = 0.0f;
58
59    $for M in range(4 + ROW_TILE):
60      float vi${M}x1 = 0.0f;
61
62    $for M in range(4 + ROW_TILE):
63      float vi${M}x2 = *i${M}++;
64
65    size_t w = input_width;
66    if (w > 1 * sizeof(float)) {
67      $for M in range(4 + ROW_TILE):
68        float vi${M}x3 = *i${M}++;
69
70      for (; w > 2 * sizeof(float); w -= 1 * sizeof(float)) {
71        $for M in range(4 + ROW_TILE):
72          const float vi${M}x4 = *i${M}++;
73
74        $for K in range(5):
75          $for M in range(ROW_TILE):
76            $if K == 0:
77              float vo${M}p0 = vbias + vi${M+K}x0 * vk${K}0;
78            $elif K < ACCUMULATORS:
79              float vo${M}p${K} = vi${M+K}x0 * vk${K}0;
80            $else:
81              vo${M}p${K % ACCUMULATORS} += vi${M+K}x0 * vk${K}0;
82
83        $for M in range(4 + ROW_TILE):
84          vi${M}x0 = vi${M}x1;
85
86        $for K in range(5):
87          $for M in range(ROW_TILE):
88            $if K+5 < ACCUMULATORS:
89              float vo${M}p${K+5} = vi${M+K}x1 * vk${K}1;
90            $else:
91              vo${M}p${(K+5) % ACCUMULATORS} += vi${M+K}x1 * vk${K}1;
92
93        $for M in range(4 + ROW_TILE):
94          vi${M}x1 = vi${M}x2;
95
96        $for K in range(5):
97          $for M in range(ROW_TILE):
98            vo${M}p${(K+10) % ACCUMULATORS} += vi${M+K}x2 * vk${K}2;
99
100        $for M in range(4 + ROW_TILE):
101          vi${M}x2 = vi${M}x3;
102
103        $for K in range(5):
104          $for M in range(ROW_TILE):
105            vo${M}p${(K+15) % ACCUMULATORS} += vi${M+K}x3 * vk${K}3;
106
107        $for M in range(4 + ROW_TILE):
108          vi${M}x3 = vi${M}x4;
109
110        $for K in range(5):
111          $for M in range(ROW_TILE):
112            vo${M}p${(K+20) % ACCUMULATORS} += vi${M+K}x4 * vk${K}4;
113
114        $if ACCUMULATORS > 1:
115          $ACC_SLICE = 1
116          $while ACC_SLICE < ACCUMULATORS:
117            $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
118              $if A + ACC_SLICE < ACCUMULATORS:
119                $for M in range(ROW_TILE):
120                  vo${M}p${A} += vo${M}p${A + ACC_SLICE};
121            $ACC_SLICE *= 2
122
123        $for M in range(ROW_TILE):
124          float vo${M} = math_max_f32(vo${M}p0, vmin);
125
126        $for M in range(ROW_TILE):
127          vo${M} = math_min_f32(vo${M}, vmax);
128
129        $for M in reversed(range(ROW_TILE)):
130          *o${M}++ = vo${M};
131      }
132      assert(w == 2 * sizeof(float));
133      {
134        $for K in range(5):
135          $for M in range(ROW_TILE):
136            $if K == 0:
137              float vo${M}p0 = vbias + vi${M+K}x0 * vk${K}0;
138            $elif K < ACCUMULATORS:
139              float vo${M}p${K} = vi${M+K}x0 * vk${K}0;
140            $else:
141              vo${M}p${K % ACCUMULATORS} += vi${M+K}x0 * vk${K}0;
142
143        $for M in range(4 + ROW_TILE):
144          vi${M}x0 = vi${M}x1;
145
146        $for K in range(5):
147          $for M in range(ROW_TILE):
148            $if K+5 < ACCUMULATORS:
149              float vo${M}p${K+5} = vi${M+K}x1 * vk${K}1;
150            $else:
151              vo${M}p${(K+5) % ACCUMULATORS} += vi${M+K}x1 * vk${K}1;
152
153        $for M in range(4 + ROW_TILE):
154          vi${M}x1 = vi${M}x2;
155
156        $for K in range(5):
157          $for M in range(ROW_TILE):
158            vo${M}p${(K+10) % ACCUMULATORS} += vi${M+K}x2 * vk${K}2;
159
160        $for M in range(4 + ROW_TILE):
161          vi${M}x2 = vi${M}x3;
162
163        $for K in range(5):
164          $for M in range(ROW_TILE):
165            vo${M}p${(K+15) % ACCUMULATORS} += vi${M+K}x3 * vk${K}3;
166
167        $if ACCUMULATORS > 1:
168          $ACC_SLICE = 1
169          $while ACC_SLICE < ACCUMULATORS:
170            $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
171              $if A + ACC_SLICE < ACCUMULATORS:
172                $for M in range(ROW_TILE):
173                  vo${M}p${A} += vo${M}p${A + ACC_SLICE};
174            $ACC_SLICE *= 2
175
176        $for M in range(ROW_TILE):
177          float vo${M} = math_max_f32(vo${M}p0, vmin);
178
179        $for M in range(ROW_TILE):
180          vo${M} = math_min_f32(vo${M}, vmax);
181
182        $for M in reversed(range(ROW_TILE)):
183          *o${M}++ = vo${M};
184      }
185      w -= 1 * sizeof(float);
186    }
187    assert(w == 1 * sizeof(float));
188    {
189      $for K in range(5):
190        $for M in range(ROW_TILE):
191          $if K == 0:
192            float vo${M}p0 = vbias + vi${M+K}x0 * vk${K}0;
193          $elif K < ACCUMULATORS:
194            float vo${M}p${K} = vi${M+K}x0 * vk${K}0;
195          $else:
196            vo${M}p${K % ACCUMULATORS} += vi${M+K}x0 * vk${K}0;
197
198      $for K in range(5):
199        $for M in range(ROW_TILE):
200          $if K+5 < ACCUMULATORS:
201            float vo${M}p${K+5} = vi${M+K}x1 * vk${K}1;
202          $else:
203            vo${M}p${(K+5) % ACCUMULATORS} += vi${M+K}x1 * vk${K}1;
204
205      $for K in range(5):
206        $for M in range(ROW_TILE):
207          vo${M}p${(K+10) % ACCUMULATORS} += vi${M+K}x2 * vk${K}2;
208
209      $if ACCUMULATORS > 1:
210        $ACC_SLICE = 1
211        $while ACC_SLICE < ACCUMULATORS:
212          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
213            $if A + ACC_SLICE < ACCUMULATORS:
214              $for M in range(ROW_TILE):
215                vo${M}p${A} += vo${M}p${A + ACC_SLICE};
216          $ACC_SLICE *= 2
217
218      $for M in range(ROW_TILE):
219        float vo${M} = math_max_f32(vo${M}p0, vmin);
220
221      $for M in range(ROW_TILE):
222        vo${M} = math_min_f32(vo${M}, vmax);
223
224      $for M in reversed(range(ROW_TILE)):
225        *o${M}++ = vo${M};
226    }
227
228    i0 = (const float*) ((uintptr_t) i${ROW_TILE} - input_width);
229    i1 = (const float*) ((uintptr_t) i${ROW_TILE+1} - input_width);
230    $if ROW_TILE > 1:
231      i2 = i${ROW_TILE+1};
232      i3 = i${ROW_TILE+2};
233      i4 = i${ROW_TILE+3};
234      $for M in range(5, 4 + ROW_TILE):
235        i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);
236
237    $if ROW_TILE > 1:
238      o0 = o${ROW_TILE - 1};
239      $for M in range(1, ROW_TILE):
240        o${M} = (float*) ((uintptr_t) o${M-1} + input_width);
241
242    $if ROW_TILE > 1:
243      output_height = doz(output_height, ${ROW_TILE});
244  } while (${"--" if ROW_TILE == 1 else ""}output_height != 0);
245}
246