1// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6#include <assert.h>
7
8#include <xnnpack/dwconv.h>
9#include <xnnpack/math.h>
10
11
12void xnn_f32_dwconv2d_chw_ukernel_5x5s2p2__scalar_${ROW_TILE}x1${"_acc%d" % ACCUMULATORS if ACCUMULATORS > 1 else ""}(
13    size_t input_height,
14    size_t input_width,
15    const float* input,
16    const float* weights,
17    const float* zero,
18    float* output,
19    uint32_t padding_top,
20    const union xnn_f32_chw_params params[restrict XNN_MIN_ELEMENTS(1)])
21{
22  assert(input_height != 0);
23  assert(input_width != 0);
24  assert(input_width % sizeof(float) == 0);
25  assert(padding_top >= 1);
26  assert(padding_top <= 2);
27
28  const float vmax = params->scalar.max;
29  const float vmin = params->scalar.min;
30
31  const float vbias = weights[0];
32  $for R in range(5):
33    $for S in range(5):
34      const float vk${R}${S} = weights[${R*5+S+1}];
35
36  const uint32_t padding_top_less_1 = padding_top - 1;
37
38  const float* i0 = zero;
39  const float* i1 = (const float*) ((uintptr_t) input - ((-padding_top_less_1) & input_width));
40  const float* i2 = (const float*) ((uintptr_t) i1 + input_width);
41  if XNN_UNPREDICTABLE(padding_top_less_1 != 0) {
42    i1 = zero;
43  }
44  $for M in range(3, 3 + 2 * ROW_TILE):
45    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);
46
47  $if ROW_TILE > 1:
48    const size_t output_width = round_down_po2((input_width + (2 /* padding */ - 3 /* kernel size */ + 2 /* subsampling */) * sizeof(float)) / 2, sizeof(float));
49
50  float* o0 = output;
51  $for M in range(1, ROW_TILE):
52    float* o${M} = (float*) ((uintptr_t) o${M-1} + output_width);
53
54  size_t padded_input_height = input_height + (padding_top_less_1 + 1) + 2 /* padding bottom */;
55  size_t output_height = (padded_input_height - 5 /* kernel size */ + 2 /* subsampling */) / 2;
56  do {
57    $for M in range(3, 3 + 2 * ROW_TILE):
58      if XNN_UNPREDICTABLE(padded_input_height < ${3 + M}) {
59        i${M} = zero;
60        $if M % 2 == 0 and M <= 2 * ROW_TILE + 1:
61          o${M / 2 - 1} = o${M / 2 - 2};
62      }
63
64    $for M in range(3 + 2 * ROW_TILE):
65      float vi${M}x0 = 0.0f;
66
67    $for M in range(3 + 2 * ROW_TILE):
68      float vi${M}x1 = 0.0f;
69
70    $for M in range(3 + 2 * ROW_TILE):
71      float vi${M}x2 = *i${M}++;
72
73    size_t w = input_width;
74    for (; w > 2 * sizeof(float); w -= 2 * sizeof(float)) {
75      $for M in range(3 + 2 * ROW_TILE):
76        const float vi${M}x3 = i${M}[0];
77
78      $for M in range(3 + 2 * ROW_TILE):
79        const float vi${M}x4 = i${M}[1];
80        i${M} += 2;
81
82      $for K in range(5):
83        $for M in range(ROW_TILE):
84          $if K == 0:
85            float vo${M}p0 = vbias + vi${2*M+K}x0 * vk${K}0;
86          $elif K < ACCUMULATORS:
87            float vo${M}p${K} = vi${2*M+K}x0 * vk${K}0;
88          $else:
89            vo${M}p${K % ACCUMULATORS} += vi${2*M+K}x0 * vk${K}0;
90
91      $for M in range(3 + 2 * ROW_TILE):
92        vi${M}x0 = vi${M}x2;
93
94      $for K in range(5):
95        $for M in range(ROW_TILE):
96          $if K+5 < ACCUMULATORS:
97            float vo${M}p${K+5} = vi${2*M+K}x1 * vk${K}1;
98          $else:
99            vo${M}p${(K+5) % ACCUMULATORS} += vi${2*M+K}x1 * vk${K}1;
100
101      $for M in range(3 + 2 * ROW_TILE):
102        vi${M}x1 = vi${M}x3;
103
104      $for K in range(5):
105        $for M in range(ROW_TILE):
106          vo${M}p${(K+10) % ACCUMULATORS} += vi${2*M+K}x2 * vk${K}2;
107
108      $for M in range(3 + 2 * ROW_TILE):
109        vi${M}x2 = vi${M}x4;
110
111      $for K in range(5):
112        $for M in range(ROW_TILE):
113          vo${M}p${(K+15) % ACCUMULATORS} += vi${2*M+K}x3 * vk${K}3;
114
115      $for K in range(5):
116        $for M in range(ROW_TILE):
117          vo${M}p${(K+20) % ACCUMULATORS} += vi${2*M+K}x4 * vk${K}4;
118
119      $if ACCUMULATORS > 1:
120        $ACC_SLICE = 1
121        $while ACC_SLICE < ACCUMULATORS:
122          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
123            $if A + ACC_SLICE < ACCUMULATORS:
124              $for M in range(ROW_TILE):
125                vo${M}p${A} += vo${M}p${A + ACC_SLICE};
126          $ACC_SLICE *= 2
127
128      $for M in range(ROW_TILE):
129        float vo${M} = math_max_f32(vo${M}p0, vmin);
130
131      $for M in range(ROW_TILE):
132        vo${M} = math_min_f32(vo${M}, vmax);
133
134      $for M in reversed(range(ROW_TILE)):
135        *o${M}++ = vo${M};
136    }
137    if XNN_LIKELY(w == 2 * sizeof(float)) {
138      $for M in range(3 + 2 * ROW_TILE):
139        const float vi${M}x3 = *i${M}++;
140
141      $for K in range(5):
142        $for M in range(ROW_TILE):
143          $if K == 0:
144            float vo${M}p0 = vbias + vi${2*M+K}x0 * vk${K}0;
145          $elif K < ACCUMULATORS:
146            float vo${M}p${K} = vi${2*M+K}x0 * vk${K}0;
147          $else:
148            vo${M}p${K % ACCUMULATORS} += vi${2*M+K}x0 * vk${K}0;
149
150      $for K in range(5):
151        $for M in range(ROW_TILE):
152          $if K+5 < ACCUMULATORS:
153            float vo${M}p${K+5} = vi${2*M+K}x1 * vk${K}1;
154          $else:
155            vo${M}p${(K+5) % ACCUMULATORS} += vi${2*M+K}x1 * vk${K}1;
156
157      $for K in range(5):
158        $for M in range(ROW_TILE):
159          vo${M}p${(K+10) % ACCUMULATORS} += vi${2*M+K}x2 * vk${K}2;
160
161      $for K in range(5):
162        $for M in range(ROW_TILE):
163          vo${M}p${(K+15) % ACCUMULATORS} += vi${2*M+K}x3 * vk${K}3;
164
165      $if ACCUMULATORS > 1:
166        $ACC_SLICE = 1
167        $while ACC_SLICE < ACCUMULATORS:
168          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
169            $if A + ACC_SLICE < ACCUMULATORS:
170              $for M in range(ROW_TILE):
171                vo${M}p${A} += vo${M}p${A + ACC_SLICE};
172          $ACC_SLICE *= 2
173
174      $for M in range(ROW_TILE):
175        float vo${M} = math_max_f32(vo${M}p0, vmin);
176
177      $for M in range(ROW_TILE):
178        vo${M} = math_min_f32(vo${M}, vmax);
179
180      $for M in reversed(range(ROW_TILE)):
181        *o${M}++ = vo${M};
182    } else {
183      $for K in range(5):
184        $for M in range(ROW_TILE):
185          $if K == 0:
186            float vo${M}p0 = vbias + vi${2*M+K}x0 * vk${K}0;
187          $elif K < ACCUMULATORS:
188            float vo${M}p${K} = vi${2*M+K}x0 * vk${K}0;
189          $else:
190            vo${M}p${K % ACCUMULATORS} += vi${2*M+K}x0 * vk${K}0;
191
192      $for K in range(5):
193        $for M in range(ROW_TILE):
194          $if K+5 < ACCUMULATORS:
195            float vo${M}p${K+5} = vi${2*M+K}x1 * vk${K}1;
196          $else:
197            vo${M}p${(K+5) % ACCUMULATORS} += vi${2*M+K}x1 * vk${K}1;
198
199      $for K in range(5):
200        $for M in range(ROW_TILE):
201          vo${M}p${(K+10) % ACCUMULATORS} += vi${2*M+K}x2 * vk${K}2;
202
203      $if ACCUMULATORS > 1:
204        $ACC_SLICE = 1
205        $while ACC_SLICE < ACCUMULATORS:
206          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
207            $if A + ACC_SLICE < ACCUMULATORS:
208              $for M in range(ROW_TILE):
209                vo${M}p${A} += vo${M}p${A + ACC_SLICE};
210          $ACC_SLICE *= 2
211
212      $for M in range(ROW_TILE):
213        float vo${M} = math_max_f32(vo${M}p0, vmin);
214
215      $for M in range(ROW_TILE):
216        vo${M} = math_min_f32(vo${M}, vmax);
217
218      $for M in reversed(range(ROW_TILE)):
219        *o${M}++ = vo${M};
220    }
221
222    $if ROW_TILE == 1:
223      i0 = (const float*) ((uintptr_t) i${2 * ROW_TILE} - input_width);
224    $else:
225      i0 = (const float*) ((uintptr_t) i${2 * ROW_TILE - 1});
226    i1 = (const float*) ((uintptr_t) i${2 * ROW_TILE});
227    i2 = (const float*) ((uintptr_t) i${2 * ROW_TILE + 1});
228    i3 = (const float*) ((uintptr_t) i${2 * ROW_TILE + 2});
229    $for M in range(4, 3 + 2 * ROW_TILE):
230      i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);
231
232    $if ROW_TILE > 1:
233      o0 = o${ROW_TILE - 1};
234      $for M in range(1, ROW_TILE):
235        o${M} = (float*) ((uintptr_t) o${M-1} + output_width);
236
237    $if ROW_TILE > 1:
238      output_height = doz(output_height, ${ROW_TILE});
239      padded_input_height = doz(padded_input_height, ${ROW_TILE * 2});
240    $else:
241      output_height -= 1;
242      padded_input_height -= 2;
243  } while (output_height != 0);
244}
245