1// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert ROW_TILE >= 1
7$assert ACCUMULATORS >= 1
8$VMULADDQ_LANE_F32 = "vfmaq_lane_f32" if FMA else "vmlaq_lane_f32"
9#include <assert.h>
10
11#include <arm_neon.h>
12
13#include <xnnpack/dwconv.h>
14#include <xnnpack/math.h>
15
16
17void xnn_f32_dwconv2d_chw_ukernel_3x3s2p1__${"neonfma" if FMA else "neon"}_${ROW_TILE}x4${"_acc%d" % ACCUMULATORS if ACCUMULATORS > 1 else ""}(
18    size_t input_height,
19    size_t input_width,
20    const float* input,
21    const float* weights,
22    const float* zero,
23    float* output,
24    uint32_t padding_top,
25    const union xnn_f32_chw_params params[restrict XNN_MIN_ELEMENTS(1)])
26{
27  assert(input_height != 0);
28  assert(input_width != 0);
29  assert(input_width % sizeof(float) == 0);
30  assert(padding_top >= 0);
31  assert(padding_top <= 1);
32
33  const uint32x4_t vmask_even = vld1q_u32(params->neon.mask_even);
34  const uint32x4_t vmask_odd  = vld1q_u32(params->neon.mask_odd);
35  const float32x4_t vmax = vld1q_dup_f32(&params->neon.max);
36  const float32x4_t vmin = vld1q_dup_f32(&params->neon.min);
37
38  const float32x4_t vw0123 = vld1q_f32(weights);
39  const float32x4_t vw4567 = vld1q_f32(weights + 4);
40  const float32x2_t vw89 = vld1_f32(weights + 8);
41
42  const size_t input_decrement = round_down_po2(input_width, 4 /* SIMD output width */ * 2 /* subsampling */ * sizeof(float));
43  $if ROW_TILE > 1:
44    const size_t output_width = round_down_po2((input_width + (2 /* padding */ - 3 /* kernel size */ + 2 /* subsampling */) * sizeof(float)) / 2, sizeof(float));
45
46  const float* i0 = (const float*) ((uintptr_t) input - ((-padding_top) & input_width));
47  const float* i1 = (const float*) ((uintptr_t) i0 + input_width);
48  if XNN_UNPREDICTABLE(padding_top != 0) {
49    i0 = zero;
50  }
51  $for M in range(2, 1 + 2 * ROW_TILE):
52    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);
53
54  float* o0 = output;
55  $for M in range(1, ROW_TILE):
56    float* o${M} = (float*) ((uintptr_t) o${M-1} + output_width);
57
58  size_t padded_input_height = input_height + padding_top + 1 /* padding bottom */;
59  size_t output_height = (padded_input_height - 3 /* kernel size */ + 2 /* subsampling */) / 2;
60  do {
61    $for M in range(2, 1 + 2 * ROW_TILE):
62      if XNN_UNPREDICTABLE(padded_input_height < ${2 + M}) {
63        i${M} = zero;
64        $if M % 2 == 1:
65          o${(M - 1) / 2} = o${(M - 1) / 2 - 1};
66      }
67
68    $for M in range(1 + 2 * ROW_TILE):
69      float32x4_t vi${M}x1357 = vmovq_n_f32(0.0f);
70
71    size_t w = input_width;
72    for (; w >= 8 * sizeof(float); w -= 8 * sizeof(float)) {
73      $for M in range(ROW_TILE):
74        float32x4_t vo${M}p0 = vdupq_lane_f32(vget_low_f32(vw0123), 0);
75
76      $for M in range(1 + 2 * ROW_TILE):
77        const float32x4x2_t vi${M}x8ACE9BDF = vld2q_f32(i${M}); i${M} += 8;
78
79      $for M in range(ROW_TILE):
80        $if ACCUMULATORS > 1:
81          float32x4_t vo${M}p1 = vmulq_lane_f32(vi${2*M}x8ACE9BDF.val[0], vget_high_f32(vw0123), 0);
82        $else:
83          vo${M}p0 = ${VMULADDQ_LANE_F32}(vo${M}p0, vi${2*M}x8ACE9BDF.val[0], vget_high_f32(vw0123), 0);
84
85      $for M in range(ROW_TILE):
86        $if ACCUMULATORS > 2:
87          float32x4_t vo${M}p2 = vmulq_lane_f32(vi${2*M+1}x8ACE9BDF.val[0], vget_low_f32(vw4567), 1);
88        $else:
89          vo${M}p0 = ${VMULADDQ_LANE_F32}(vo${M}p0, vi${2*M+1}x8ACE9BDF.val[0], vget_low_f32(vw4567), 1);
90
91      $for M in range(ROW_TILE):
92        $if ACCUMULATORS > 3:
93          float32x4_t vo${M}p3 = vmulq_lane_f32(vi${2*M+2}x8ACE9BDF.val[0], vw89, 0);
94        $else:
95          vo${M}p${4 % ACCUMULATORS} = ${VMULADDQ_LANE_F32}(vo${M}p${4 % ACCUMULATORS}, vi${2*M+2}x8ACE9BDF.val[0], vw89, 0);
96
97      $for M in range(1 + 2 * ROW_TILE):
98        const float32x4_t vi${M}x7BDF = vextq_f32(vi${M}x1357, vi${M}x8ACE9BDF.val[1], 3);
99        vi${M}x1357 = vi${M}x8ACE9BDF.val[1];
100
101      $for M in range(ROW_TILE):
102        $if ACCUMULATORS > 4:
103          float32x4_t vo${M}p4 = vmulq_lane_f32(vi${2*M}x7BDF, vget_low_f32(vw0123), 1);
104        $else:
105          vo${M}p${5 % ACCUMULATORS} = ${VMULADDQ_LANE_F32}(vo${M}p${5 % ACCUMULATORS}, vi${2*M}x7BDF, vget_low_f32(vw0123), 1);
106
107      $for M in range(ROW_TILE):
108        $if ACCUMULATORS > 5:
109          float32x4_t vo${M}p5 = vmulq_lane_f32(vi${2*M+1}x7BDF, vget_low_f32(vw4567), 0);
110        $else:
111          vo${M}p${6 % ACCUMULATORS} = ${VMULADDQ_LANE_F32}(vo${M}p${6 % ACCUMULATORS}, vi${2*M+1}x7BDF, vget_low_f32(vw4567), 0);
112
113      $for M in range(ROW_TILE):
114        $if ACCUMULATORS > 6:
115          float32x4_t vo${M}p6 = vmulq_lane_f32(vi${2*M+2}x7BDF, vget_low_f32(vw4567), 1);
116        $else:
117          vo${M}p${7 % ACCUMULATORS} = ${VMULADDQ_LANE_F32}(vo${M}p${7 % ACCUMULATORS}, vi${2*M+2}x7BDF, vget_high_f32(vw4567), 1);
118
119      $for M in range(ROW_TILE):
120        vo${M}p${8 % ACCUMULATORS} = ${VMULADDQ_LANE_F32}(vo${M}p${8 % ACCUMULATORS}, vi${2*M}x8ACE9BDF.val[1], vget_high_f32(vw0123), 1);
121
122      $for M in range(ROW_TILE):
123        vo${M}p${9 % ACCUMULATORS} = ${VMULADDQ_LANE_F32}(vo${M}p${9 % ACCUMULATORS}, vi${2*M+1}x8ACE9BDF.val[1], vget_high_f32(vw4567), 0);
124
125      $for M in range(ROW_TILE):
126        vo${M}p${10 % ACCUMULATORS} = ${VMULADDQ_LANE_F32}(vo${M}p${10 % ACCUMULATORS}, vi${2*M+2}x8ACE9BDF.val[1], vw89, 1);
127
128      $if ACCUMULATORS > 1:
129        $ACC_SLICE = 1
130        $while ACC_SLICE < ACCUMULATORS:
131          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
132            $if A + ACC_SLICE < ACCUMULATORS:
133              $for M in range(ROW_TILE):
134                vo${M}p${A} = vaddq_f32(vo${M}p${A}, vo${M}p${A + ACC_SLICE});
135          $ACC_SLICE *= 2
136
137      $for M in range(ROW_TILE):
138        float32x4_t vo${M} = vmaxq_f32(vo${M}p0, vmin);
139
140      $for M in range(ROW_TILE):
141        vo${M} = vminq_f32(vo${M}, vmax);
142
143      $for M in reversed(range(ROW_TILE)):
144        vst1q_f32(o${M}, vo${M}); o${M} += 4;
145    }
146    // Last block has 0-7 pixels to process.
147    assert(w < 8 * sizeof(float));
148    if XNN_LIKELY(w != 0) {
149      $for M in range(ROW_TILE):
150        float32x4_t vo${M}p0 = vdupq_lane_f32(vget_low_f32(vw0123), 0);
151
152      $for M in range(1 + 2 * ROW_TILE):
153        const float32x4x2_t vi${M}x8ACE9BDF = vld2q_f32(i${M});
154
155      $for M in range(1 + 2 * ROW_TILE):
156        const float32x4_t vi${M}x8ACE = vreinterpretq_f32_u32(vandq_u32(vmask_even, vreinterpretq_u32_f32(vi${M}x8ACE9BDF.val[0])));
157        const float32x4_t vi${M}x9BDF = vreinterpretq_f32_u32(vandq_u32(vmask_odd,  vreinterpretq_u32_f32(vi${M}x8ACE9BDF.val[1])));
158
159      $for M in range(ROW_TILE):
160        $if ACCUMULATORS > 1:
161          float32x4_t vo${M}p1 = vmulq_lane_f32(vi${2*M}x8ACE, vget_high_f32(vw0123), 0);
162        $else:
163          vo${M}p0 = ${VMULADDQ_LANE_F32}(vo${M}p0, vi${2*M}x8ACE, vget_high_f32(vw0123), 0);
164
165      $for M in range(ROW_TILE):
166        $if ACCUMULATORS > 2:
167          float32x4_t vo${M}p2 = vmulq_lane_f32(vi${2*M+1}x8ACE, vget_low_f32(vw4567), 1);
168        $else:
169          vo${M}p0 = ${VMULADDQ_LANE_F32}(vo${M}p0, vi${2*M+1}x8ACE, vget_low_f32(vw4567), 1);
170
171      $for M in range(ROW_TILE):
172        $if ACCUMULATORS > 3:
173          float32x4_t vo${M}p3 = vmulq_lane_f32(vi${2*M+2}x8ACE, vw89, 0);
174        $else:
175          vo${M}p${4 % ACCUMULATORS} = ${VMULADDQ_LANE_F32}(vo${M}p${4 % ACCUMULATORS}, vi${2*M+2}x8ACE, vw89, 0);
176
177      $for M in range(1 + 2 * ROW_TILE):
178        const float32x4_t vi${M}x7BDF = vextq_f32(vi${M}x1357, vi${M}x9BDF, 3);
179
180      $for M in range(ROW_TILE):
181        $if ACCUMULATORS > 4:
182          float32x4_t vo${M}p4 = vmulq_lane_f32(vi${2*M}x7BDF, vget_low_f32(vw0123), 1);
183        $else:
184          vo${M}p${5 % ACCUMULATORS} = ${VMULADDQ_LANE_F32}(vo${M}p${5 % ACCUMULATORS}, vi${2*M}x7BDF, vget_low_f32(vw0123), 1);
185
186      $for M in range(ROW_TILE):
187        $if ACCUMULATORS > 5:
188          float32x4_t vo${M}p5 = vmulq_lane_f32(vi${2*M+1}x7BDF, vget_low_f32(vw4567), 0);
189        $else:
190          vo${M}p${6 % ACCUMULATORS} = ${VMULADDQ_LANE_F32}(vo${M}p${6 % ACCUMULATORS}, vi${2*M+1}x7BDF, vget_low_f32(vw4567), 0);
191
192      $for M in range(ROW_TILE):
193        $if ACCUMULATORS > 6:
194          float32x4_t vo${M}p6 = vmulq_lane_f32(vi${2*M+2}x7BDF, vget_low_f32(vw4567), 1);
195        $else:
196          vo${M}p${7 % ACCUMULATORS} = ${VMULADDQ_LANE_F32}(vo${M}p${7 % ACCUMULATORS}, vi${2*M+2}x7BDF, vget_high_f32(vw4567), 1);
197
198      $for M in range(ROW_TILE):
199        vo${M}p${8 % ACCUMULATORS} = ${VMULADDQ_LANE_F32}(vo${M}p${8 % ACCUMULATORS}, vi${2*M}x9BDF, vget_high_f32(vw0123), 1);
200
201      $for M in range(ROW_TILE):
202        vo${M}p${9 % ACCUMULATORS} = ${VMULADDQ_LANE_F32}(vo${M}p${9 % ACCUMULATORS}, vi${2*M+1}x9BDF, vget_high_f32(vw4567), 0);
203
204      $for M in range(ROW_TILE):
205        vo${M}p${10 % ACCUMULATORS} = ${VMULADDQ_LANE_F32}(vo${M}p${10 % ACCUMULATORS}, vi${2*M+2}x9BDF, vw89, 1);
206
207      $if ACCUMULATORS > 1:
208        $ACC_SLICE = 1
209        $while ACC_SLICE < ACCUMULATORS:
210          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
211            $if A + ACC_SLICE < ACCUMULATORS:
212              $for M in range(ROW_TILE):
213                vo${M}p${A} = vaddq_f32(vo${M}p${A}, vo${M}p${A + ACC_SLICE});
214          $ACC_SLICE *= 2
215
216      $for M in range(ROW_TILE):
217        float32x4_t vo${M} = vmaxq_f32(vo${M}p0, vmin);
218
219      $for M in range(ROW_TILE):
220        vo${M} = vminq_f32(vo${M}, vmax);
221
222      w += 1 * sizeof(float);
223      if (w & (8 * sizeof(float))) {
224        $for M in reversed(range(ROW_TILE)):
225          vst1q_f32(o${M}, vo${M}); o${M} += 4;
226      } else {
227        $for M in range(ROW_TILE):
228          float32x2_t vo${M}_lo = vget_low_f32(vo${M});
229        if (w & (4 * sizeof(float))) {
230          $for M in reversed(range(ROW_TILE)):
231            vst1_f32(o${M}, vo${M}_lo); o${M} += 2;
232
233          $for M in range(ROW_TILE):
234            vo${M}_lo = vget_high_f32(vo${M});
235        }
236        if (w & (2 * sizeof(float))) {
237          $for M in reversed(range(ROW_TILE)):
238            vst1_lane_f32(o${M}, vo${M}_lo, 0); o${M} += 1;
239        }
240      }
241    }
242
243    i0 = (const float*) ((uintptr_t) i${2 * ROW_TILE} - input_decrement);
244    $for M in range(1, 1 + 2 * ROW_TILE):
245      i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);
246
247    $if ROW_TILE > 1:
248      o0 = o${ROW_TILE - 1};
249      $for M in range(1, ROW_TILE):
250        o${M} = (float*) ((uintptr_t) o${M-1} + output_width);
251
252    $if ROW_TILE > 1:
253      output_height = doz(output_height, ${ROW_TILE});
254      padded_input_height = doz(padded_input_height, ${ROW_TILE * 2});
255    $else:
256      output_height -= 1;
257      padded_input_height -= 2;
258  } while (output_height != 0);
259}
260