1// Copyright 2019 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
// Template preconditions: CHANNEL_TILE is a multiple of 4 and at least 4, and PIXEL_TILE is 1.
// NOTE(review): VMULADDQ_F32 is defined below but is not referenced anywhere in the
// visible part of this template -- confirm whether it is still needed.
6$assert CHANNEL_TILE % 4 == 0
7$assert CHANNEL_TILE >= 4
8$assert PIXEL_TILE == 1
9$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
10$VMULADDQ_F32 = "vfmaq_f32" if FMA else "vmlaq_f32"
11#include <assert.h>
12
13#include <arm_neon.h>
14
15#include <xnnpack/common.h>
16#include <xnnpack/ibilinear.h>
17
18
// Template for a NEON f32 bilinear-interpolation micro-kernel. One output pixel
// is produced per iteration of the outer loop; the template parameters FMA,
// CHANNEL_TILE and PIXEL_TILE select the generated variant and its name.
//
// For every output pixel the kernel consumes four pointers from input (called
// top-left, top-right, bottom-left and bottom-right here, after the
// vtl/vtr/vbl/vbr locals) and two floats from weights (alphah, alphav).
// For each channel it stores:
//   top    = tl + (tr - tl) * alphah
//   bottom = bl + (br - bl) * alphah
//   out    = top + (bottom - top) * alphav
//
// Parameters:
//   output_pixels    - number of output pixels to produce; must be non-zero.
//   channels         - channel count expressed in bytes; must be a non-zero
//                      multiple of sizeof(float).
//   input            - array of source pointers, four consumed per output pixel.
//   input_offset     - byte offset added to every pointer read from input.
//   weights          - two floats consumed per output pixel.
//   output           - destination buffer for the interpolated channels.
//   output_increment - byte amount added to output after each pixel, on top of
//                      the floats already written.
//
// In the FMA variant, builds with XNN_ARCH_ARM set broadcast each weight into a
// full vector and use vfmaq_f32, while other builds use the by-lane
// vfmaq_lane_f32 form.
// NOTE(review): presumably the by-lane FMA intrinsic is not available on
// 32-bit ARM -- confirm against the Arm intrinsics reference.
//
// NOTE(review): the remainder path (1-3 channels left) loads full 4-float
// vectors, so it reads beyond the last valid channel of each source row;
// presumably callers guarantee that memory is readable -- confirm.
19void xnn_f32_ibilinear_ukernel__${"neonfma" if FMA else "neon"}_c${CHANNEL_TILE}${"" if PIXEL_TILE == 1 else "x%d" % PIXEL_TILE}(
20    size_t output_pixels,
21    size_t channels,
22    const float**restrict input,
23    size_t input_offset,
24    const float*restrict weights,
25    float*restrict output,
26    size_t output_increment) XNN_DISABLE_TSAN
27{
28  assert(output_pixels != 0);
29  assert(channels != 0);
30  assert(channels % sizeof(float) == 0);
31
32  do {
    // Fetch the four source pointers for this output pixel and apply the byte
    // offset to each of them.
33    const float* i0 = (const float*) ((uintptr_t) input[0] + input_offset);
34    const float* i1 = (const float*) ((uintptr_t) input[1] + input_offset);
35    const float* i2 = (const float*) ((uintptr_t) input[2] + input_offset);
36    const float* i3 = (const float*) ((uintptr_t) input[3] + input_offset);
37    input += 4;
38
    // Two weights per pixel: lane 0 is the horizontal alpha, lane 1 the vertical alpha.
39    const float32x2_t valphahv = vld1_f32(weights); weights += 2;
40    $if FMA:
41      #if XNN_ARCH_ARM
42        const float32x4_t valphah = vdupq_lane_f32(valphahv, 0);
43        const float32x4_t valphav = vdupq_lane_f32(valphahv, 1);
44      #endif
45
    // Main loop: CHANNEL_TILE channels per iteration. c counts remaining bytes.
46    size_t c = channels;
47    for (; c >= ${CHANNEL_TILE} * sizeof(float); c -= ${CHANNEL_TILE} * sizeof(float)) {
48      $for C in range(0, CHANNEL_TILE, 4):
49        const float32x4_t vtl${ABC[C:C+4]} = vld1q_f32(i0); i0 += 4;
50        const float32x4_t vtr${ABC[C:C+4]} = vld1q_f32(i1); i1 += 4;
51        const float32x4_t vbl${ABC[C:C+4]} = vld1q_f32(i2); i2 += 4;
52        const float32x4_t vbr${ABC[C:C+4]} = vld1q_f32(i3); i3 += 4;
53
54      $for C in range(0, CHANNEL_TILE, 4):
55        const float32x4_t vtd${ABC[C:C+4]} = vsubq_f32(vtr${ABC[C:C+4]}, vtl${ABC[C:C+4]});
56        const float32x4_t vbd${ABC[C:C+4]} = vsubq_f32(vbr${ABC[C:C+4]}, vbl${ABC[C:C+4]});
57
58      $if FMA:
59        #if XNN_ARCH_ARM
60        $for C in range(0, CHANNEL_TILE, 4):
61          const float32x4_t vt${ABC[C:C+4]} = vfmaq_f32(vtl${ABC[C:C+4]}, vtd${ABC[C:C+4]}, valphah);
62          const float32x4_t vb${ABC[C:C+4]} = vfmaq_f32(vbl${ABC[C:C+4]}, vbd${ABC[C:C+4]}, valphah);
63        #else
64        $for C in range(0, CHANNEL_TILE, 4):
65          const float32x4_t vt${ABC[C:C+4]} = vfmaq_lane_f32(vtl${ABC[C:C+4]}, vtd${ABC[C:C+4]}, valphahv, 0);
66          const float32x4_t vb${ABC[C:C+4]} = vfmaq_lane_f32(vbl${ABC[C:C+4]}, vbd${ABC[C:C+4]}, valphahv, 0);
67        #endif
68      $else:
69        $for C in range(0, CHANNEL_TILE, 4):
70          const float32x4_t vt${ABC[C:C+4]} = vmlaq_lane_f32(vtl${ABC[C:C+4]}, vtd${ABC[C:C+4]}, valphahv, 0);
71          const float32x4_t vb${ABC[C:C+4]} = vmlaq_lane_f32(vbl${ABC[C:C+4]}, vbd${ABC[C:C+4]}, valphahv, 0);
72
73      $for C in range(0, CHANNEL_TILE, 4):
74        const float32x4_t vd${ABC[C:C+4]} = vsubq_f32(vb${ABC[C:C+4]}, vt${ABC[C:C+4]});
75
76      $if FMA:
77        #if XNN_ARCH_ARM
78        $for C in range(0, CHANNEL_TILE, 4):
79          const float32x4_t vo${ABC[C:C+4]} = vfmaq_f32(vt${ABC[C:C+4]}, vd${ABC[C:C+4]}, valphav);
80        #else
81        $for C in range(0, CHANNEL_TILE, 4):
82          const float32x4_t vo${ABC[C:C+4]} = vfmaq_lane_f32(vt${ABC[C:C+4]}, vd${ABC[C:C+4]}, valphahv, 1);
83        #endif
84      $else:
85        $for C in range(0, CHANNEL_TILE, 4):
86          const float32x4_t vo${ABC[C:C+4]} = vmlaq_lane_f32(vt${ABC[C:C+4]}, vd${ABC[C:C+4]}, valphahv, 1);
87
88      $for C in range(0, CHANNEL_TILE, 4):
89        vst1q_f32(output, vo${ABC[C:C+4]}); output += 4;
90    }
    // Only generated when CHANNEL_TILE is larger than 4: a secondary loop that
    // processes the leftover channels four at a time.
91    $if CHANNEL_TILE > 4:
92      for (; c >= 4 * sizeof(float); c -= 4 * sizeof(float)) {
93        const float32x4_t vtl0123 = vld1q_f32(i0); i0 += 4;
94        const float32x4_t vtr0123 = vld1q_f32(i1); i1 += 4;
95        const float32x4_t vbl0123 = vld1q_f32(i2); i2 += 4;
96        const float32x4_t vbr0123 = vld1q_f32(i3); i3 += 4;
97
98        const float32x4_t vtd0123 = vsubq_f32(vtr0123, vtl0123);
99        const float32x4_t vbd0123 = vsubq_f32(vbr0123, vbl0123);
100
101        $if FMA:
102          #if XNN_ARCH_ARM
103          const float32x4_t vt0123 = vfmaq_f32(vtl0123, vtd0123, valphah);
104          const float32x4_t vb0123 = vfmaq_f32(vbl0123, vbd0123, valphah);
105          #else
106          const float32x4_t vt0123 = vfmaq_lane_f32(vtl0123, vtd0123, valphahv, 0);
107          const float32x4_t vb0123 = vfmaq_lane_f32(vbl0123, vbd0123, valphahv, 0);
108          #endif
109        $else:
110          const float32x4_t vt0123 = vmlaq_lane_f32(vtl0123, vtd0123, valphahv, 0);
111          const float32x4_t vb0123 = vmlaq_lane_f32(vbl0123, vbd0123, valphahv, 0);
112
113        const float32x4_t vd0123 = vsubq_f32(vb0123, vt0123);
114
115        $if FMA:
116          #if XNN_ARCH_ARM
117          const float32x4_t vo0123 = vfmaq_f32(vt0123, vd0123, valphav);
118          #else
119          const float32x4_t vo0123 = vfmaq_lane_f32(vt0123, vd0123, valphahv, 1);
120          #endif
121        $else:
122          const float32x4_t vo0123 = vmlaq_lane_f32(vt0123, vd0123, valphahv, 1);
123
124        vst1q_f32(output, vo0123);
125        output += 4;
126      }
    // Remainder: at this point c is 4, 8 or 12 bytes (1-3 floats), because
    // channels is a multiple of sizeof(float) and the loops above stop below
    // 4 floats. A full vector is computed and only the valid lanes are stored.
127    if XNN_UNLIKELY(c != 0) {
128      const float32x4_t vtl0123 = vld1q_f32(i0);
129      const float32x4_t vtr0123 = vld1q_f32(i1);
130      const float32x4_t vbl0123 = vld1q_f32(i2);
131      const float32x4_t vbr0123 = vld1q_f32(i3);
132
133      const float32x4_t vtd0123 = vsubq_f32(vtr0123, vtl0123);
134      const float32x4_t vbd0123 = vsubq_f32(vbr0123, vbl0123);
135
      // NOTE(review): the template conditional below is indented two columns
      // deeper than the matching conditional further down in this block; this
      // looks unintentional -- confirm whether it affects generated formatting.
136        $if FMA:
137          #if XNN_ARCH_ARM
138          const float32x4_t vt0123 = vfmaq_f32(vtl0123, vtd0123, valphah);
139          const float32x4_t vb0123 = vfmaq_f32(vbl0123, vbd0123, valphah);
140          #else
141          const float32x4_t vt0123 = vfmaq_lane_f32(vtl0123, vtd0123, valphahv, 0);
142          const float32x4_t vb0123 = vfmaq_lane_f32(vbl0123, vbd0123, valphahv, 0);
143          #endif
144        $else:
145          const float32x4_t vt0123 = vmlaq_lane_f32(vtl0123, vtd0123, valphahv, 0);
146          const float32x4_t vb0123 = vmlaq_lane_f32(vbl0123, vbd0123, valphahv, 0);
147
148      const float32x4_t vd0123 = vsubq_f32(vb0123, vt0123);
149
      // NOTE(review): vo0123 is declared const only in the non-FMA branch below,
      // although it is never reassigned in any branch.
150      $if FMA:
151        #if XNN_ARCH_ARM
152        float32x4_t vo0123 = vfmaq_f32(vt0123, vd0123, valphav);
153        #else
154        float32x4_t vo0123 = vfmaq_lane_f32(vt0123, vd0123, valphahv, 1);
155        #endif
156      $else:
157        const float32x4_t vo0123 = vmlaq_lane_f32(vt0123, vd0123, valphahv, 1);
158
      // Partial store: write two floats if the 2-float bit of c is set, then
      // switch to the high half so a following single-float store picks up the
      // next unwritten lane.
159      float32x2_t vo01 = vget_low_f32(vo0123);
160      if (c & (2 * sizeof(float))) {
161        vst1_f32(output, vo01); output += 2;
162        vo01 = vget_high_f32(vo0123);
163      }
164      if (c & (1 * sizeof(float))) {
165        vst1_lane_f32(output, vo01, 0); output += 1;
166      }
167    }
168
    // Apply the additional byte increment to reach the next output pixel.
169    output = (float*) ((uintptr_t) output + output_increment);
170  } while (--output_pixels != 0);
171}
172