1// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
// Template parameters:
//   PIXEL_TILE - number of output pixels processed per iteration of the main
//                loop; asserted below to be a positive multiple of 4.
//   FMA        - when true the kernel uses vfmaq_f32/vfma_f32 for the
//                multiply-adds, otherwise vmlaq_f32/vmla_f32.
// ABC maps a pixel index to a single character, used as a suffix when naming
// generated per-pixel variables.
$assert PIXEL_TILE >= 1
$assert PIXEL_TILE % 4 == 0
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$VMULADDQ_F32 = "vfmaq_f32" if FMA else "vmlaq_f32"
$VMULADD_F32 = "vfma_f32" if FMA else "vmla_f32"
11#include <assert.h>
12
13#include <arm_neon.h>
14
15#include <xnnpack/ibilinear.h>
16
17
18void xnn_f32_ibilinear_chw_ukernel__${"neonfma" if FMA else "neon"}_p${PIXEL_TILE}(
19    size_t output_pixels,
20    size_t channels,
21    const float**restrict input,
22    size_t input_offset,
23    const float*restrict weights,
24    float*restrict output,
25    size_t input_increment) XNN_DISABLE_TSAN
26{
27  assert(output_pixels != 0);
28  assert(channels != 0);
29  assert(input_increment % sizeof(float) == 0);
30
31  do {
32    const float** i = input;
33    const float* w = weights;
34    size_t p = output_pixels;
35    $if PIXEL_TILE > 4:
36      for (; p >= ${PIXEL_TILE}; p -= ${PIXEL_TILE}) {
37        $for P in range(PIXEL_TILE):
38          const float* itl${ABC[P]} = (const float*) ((uintptr_t) i[${2 * P}] + input_offset);
39          const float* ibl${ABC[P]} = (const float*) ((uintptr_t) i[${2 * P + 1}] + input_offset);
40        i += 2 * ${PIXEL_TILE};
41
42        $for P in range(0, PIXEL_TILE, 4):
43          const float32x4x2_t vw${ABC[P:P+4]} = vld2q_f32(w + ${2 * P});
44        w += 2 * ${PIXEL_TILE};
45
46        $for P in range(0, PIXEL_TILE):
47          const float32x2_t vtltr${ABC[P]} = vld1_f32(itl${P});
48          const float32x2_t vblbr${ABC[P]} = vld1_f32(ibl${P});
49
50        $for P in range(0, PIXEL_TILE, 4):
51          const float32x4_t valphah${ABC[P:P+4]} = vw${ABC[P:P+4]}.val[0];
52          const float32x4_t valphav${ABC[P:P+4]} = vw${ABC[P:P+4]}.val[1];
53
54        $for P in range(0, PIXEL_TILE, 2):
55          const float32x4_t vtltr${ABC[P:P+2]} = vcombine_f32(vtltr${ABC[P]}, vtltr${ABC[P+1]});
56          const float32x4_t vblbr${ABC[P:P+2]} = vcombine_f32(vblbr${ABC[P]}, vblbr${ABC[P+1]});
57
58        $for P in range(0, PIXEL_TILE, 2):
59          const float32x4_t vldrd${ABC[P:P+2]} = vsubq_f32(vblbr${ABC[P:P+2]}, vtltr${ABC[P:P+2]});
60
61        $for P in range(0, PIXEL_TILE, 4):
62          const float32x4x2_t vld_t${ABC[P:P+4]} = vuzpq_f32(vldrd${ABC[P:P+2]}, vldrd${ABC[P+2:P+4]});
63          const float32x4_t vld${ABC[P:P+4]} = vld_t${ABC[P:P+4]}.val[0];
64          const float32x4_t vrd${ABC[P:P+4]} = vld_t${ABC[P:P+4]}.val[1];
65
66        $for P in range(0, PIXEL_TILE, 4):
67          const float32x4x2_t vtl_t${ABC[P:P+4]} = vuzpq_f32(vtltr${ABC[P:P+2]}, vtltr${ABC[P+2:P+4]});
68          const float32x4_t vtl${ABC[P:P+4]} = vtl_t${ABC[P:P+4]}.val[0];
69          const float32x4_t vtr${ABC[P:P+4]} = vtl_t${ABC[P:P+4]}.val[1];
70
71        $for P in range(0, PIXEL_TILE, 4):
72          const float32x4_t vl${ABC[P:P+4]} = ${VMULADDQ_F32}(vtl${ABC[P:P+4]}, vld${ABC[P:P+4]}, valphav${ABC[P:P+4]});
73          const float32x4_t vr${ABC[P:P+4]} = ${VMULADDQ_F32}(vtr${ABC[P:P+4]}, vrd${ABC[P:P+4]}, valphav${ABC[P:P+4]});
74
75        $for P in range(0, PIXEL_TILE, 4):
76          const float32x4_t vd${ABC[P:P+4]} = vsubq_f32(vr${ABC[P:P+4]}, vl${ABC[P:P+4]});
77
78        $for P in range(0, PIXEL_TILE, 4):
79          const float32x4_t vo${ABC[P:P+4]} = ${VMULADDQ_F32}(vl${ABC[P:P+4]}, vd${ABC[P:P+4]}, valphah${ABC[P:P+4]});
80
81        $for P in range(0, PIXEL_TILE, 4):
82          vst1q_f32(output + ${P}, vo${ABC[P:P+4]});
83        output += ${PIXEL_TILE};
84      }
85
86    for (; p >= 4; p -= 4) {
87      $for P in range(4):
88        const float* itl${P} = (const float*) ((uintptr_t) i[${2 * P}] + input_offset);
89        const float* ibl${P} = (const float*) ((uintptr_t) i[${2 * P + 1}] + input_offset);
90      i += 8;
91
92      const float32x4x2_t vw = vld2q_f32(w);
93      w += 8;
94
95      $for P in range(0, 4):
96        const float32x2_t vtltr${ABC[P]} = vld1_f32(itl${P});
97        const float32x2_t vblbr${ABC[P]} = vld1_f32(ibl${P});
98
99      const float32x4_t valphah = vw.val[0];
100      const float32x4_t valphav = vw.val[1];
101
102      $for P in range(0, 4, 2):
103        const float32x4_t vtltr${ABC[P:P+2]} = vcombine_f32(vtltr${ABC[P]}, vtltr${ABC[P+1]});
104        const float32x4_t vblbr${ABC[P:P+2]} = vcombine_f32(vblbr${ABC[P]}, vblbr${ABC[P+1]});
105
106      $for P in range(0, 4, 2):
107        const float32x4_t vldrd${ABC[P:P+2]} = vsubq_f32(vblbr${ABC[P:P+2]}, vtltr${ABC[P:P+2]});
108
109      const float32x4x2_t vld_t = vuzpq_f32(vldrd01, vldrd23);
110      const float32x4_t vld = vld_t.val[0];
111      const float32x4_t vrd = vld_t.val[1];
112
113      const float32x4x2_t vtl_t = vuzpq_f32(vtltr01, vtltr23);
114      const float32x4_t vtl = vtl_t.val[0];
115      const float32x4_t vtr = vtl_t.val[1];
116
117      const float32x4_t vl = ${VMULADDQ_F32}(vtl, vld, valphav);
118      const float32x4_t vr = ${VMULADDQ_F32}(vtr, vrd, valphav);
119
120      const float32x4_t vd = vsubq_f32(vr, vl);
121      const float32x4_t vo = ${VMULADDQ_F32}(vl, vd, valphah);
122
123      vst1q_f32(output, vo);
124      output += 4;
125    }
126
127    if XNN_UNLIKELY(p != 0) {
128      if (p & 2) {
129        const float32x2x2_t vw = vld2_f32(w);
130        w += 4;
131
132        const float32x2_t valphah = vw.val[0];
133        const float32x2_t valphav = vw.val[1];
134
135        $for P in range(2):
136          const float* itl${P} = (const float*) ((uintptr_t) i[${2 * P}] + input_offset);
137          const float* ibl${P} = (const float*) ((uintptr_t) i[${2 * P + 1}] + input_offset);
138        i += 4;
139
140        $for P in range(0, 2):
141          const float32x2_t vtltr${ABC[P]} = vld1_f32(itl${P});
142          const float32x2_t vblbr${ABC[P]} = vld1_f32(ibl${P});
143
144        $for P in range(0, 2):
145          const float32x2_t vldrd${ABC[P]} = vsub_f32(vblbr${ABC[P]}, vtltr${ABC[P]});
146
147        const float32x2x2_t vld_t = vuzp_f32(vldrd0, vldrd1);
148        const float32x2_t vld = vld_t.val[0];
149        const float32x2_t vrd = vld_t.val[1];
150
151        const float32x2x2_t vtl_t = vuzp_f32(vtltr0, vtltr1);
152        const float32x2_t vtl = vtl_t.val[0];
153        const float32x2_t vtr = vtl_t.val[1];
154
155        const float32x2_t vl = ${VMULADD_F32}(vtl, vld, valphav);
156        const float32x2_t vr = ${VMULADD_F32}(vtr, vrd, valphav);
157
158        const float32x2_t vd = vsub_f32(vr, vl);
159        const float32x2_t vo = ${VMULADD_F32}(vl, vd, valphah);
160
161        vst1_f32(output, vo);
162        output += 2;
163      }
164
165      if (p & 1) {
166        // We are computing the following formula:
167        //   result = (1 - alpha_h) * (1 - alpha_v) * top_left +
168        //                 alpha_h  * (1 - alpha_v) * top_right +
169        //            (1 - alpha_h) *      alpha_v  * bottom_left +
170        //                 alpha_h  *      alpha_v  * bottom_right.
171        //
172        // Rearranging gives
173        //   result =    left + alpha_h * (right        - left),
174        // where
175        //   left =  top_left + alpha_v * (bottom_left  - top_left),
176        //  right = top_right + alpha_v * (bottom_right - top_right).
177
178        const float alphah = *w;
179        const float32x2_t valphav = vld1_dup_f32(w + 1);
180        w += 2;
181
182        const float* itl = (const float*) ((uintptr_t) i[0] + input_offset);
183        const float* ibl = (const float*) ((uintptr_t) i[1] + input_offset);
184        i += 2;
185
186        const float32x2_t vtltr = vld1_f32(itl);
187        const float32x2_t vblbr = vld1_f32(ibl);
188
189        // Compute at once
190        //    left_diff = bottom_left  - top_left
191        //   right_diff = bottom_right - top_right
192        const float32x2_t vldrd = vsub_f32(vblbr, vtltr);
193        const float32x2_t vlr = ${VMULADD_F32}(vtltr, vldrd, valphav);
194
195        // Extract them and compute the result.
196        const float l = vget_lane_f32(vlr, 0);
197        const float r = vget_lane_f32(vlr, 1);
198
199        *output++ = l + alphah * (r - l);
200      }
201    }
202
203    input_offset += input_increment;
204  } while (--channels != 0);
205}
206