1// Copyright 2020 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert PIXEL_TILE >= 1 7$assert PIXEL_TILE % 4 == 0 8$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 9$VMULADDQ_F32 = "vfmaq_f32" if FMA else "vmlaq_f32" 10$VMULADD_F32 = "vfma_f32" if FMA else "vmla_f32" 11#include <assert.h> 12 13#include <arm_neon.h> 14 15#include <xnnpack/ibilinear.h> 16 17 18void xnn_f32_ibilinear_chw_ukernel__${"neonfma" if FMA else "neon"}_p${PIXEL_TILE}( 19 size_t output_pixels, 20 size_t channels, 21 const float**restrict input, 22 size_t input_offset, 23 const float*restrict weights, 24 float*restrict output, 25 size_t input_increment) XNN_DISABLE_TSAN 26{ 27 assert(output_pixels != 0); 28 assert(channels != 0); 29 assert(input_increment % sizeof(float) == 0); 30 31 do { 32 const float** i = input; 33 const float* w = weights; 34 size_t p = output_pixels; 35 $if PIXEL_TILE > 4: 36 for (; p >= ${PIXEL_TILE}; p -= ${PIXEL_TILE}) { 37 $for P in range(PIXEL_TILE): 38 const float* itl${ABC[P]} = (const float*) ((uintptr_t) i[${2 * P}] + input_offset); 39 const float* ibl${ABC[P]} = (const float*) ((uintptr_t) i[${2 * P + 1}] + input_offset); 40 i += 2 * ${PIXEL_TILE}; 41 42 $for P in range(0, PIXEL_TILE, 4): 43 const float32x4x2_t vw${ABC[P:P+4]} = vld2q_f32(w + ${2 * P}); 44 w += 2 * ${PIXEL_TILE}; 45 46 $for P in range(0, PIXEL_TILE): 47 const float32x2_t vtltr${ABC[P]} = vld1_f32(itl${P}); 48 const float32x2_t vblbr${ABC[P]} = vld1_f32(ibl${P}); 49 50 $for P in range(0, PIXEL_TILE, 4): 51 const float32x4_t valphah${ABC[P:P+4]} = vw${ABC[P:P+4]}.val[0]; 52 const float32x4_t valphav${ABC[P:P+4]} = vw${ABC[P:P+4]}.val[1]; 53 54 $for P in range(0, PIXEL_TILE, 2): 55 const float32x4_t vtltr${ABC[P:P+2]} = vcombine_f32(vtltr${ABC[P]}, vtltr${ABC[P+1]}); 56 const float32x4_t vblbr${ABC[P:P+2]} = vcombine_f32(vblbr${ABC[P]}, vblbr${ABC[P+1]}); 57 58 $for P in range(0, PIXEL_TILE, 2): 59 const float32x4_t vldrd${ABC[P:P+2]} = vsubq_f32(vblbr${ABC[P:P+2]}, vtltr${ABC[P:P+2]}); 60 61 $for P in range(0, PIXEL_TILE, 4): 62 const float32x4x2_t vld_t${ABC[P:P+4]} = vuzpq_f32(vldrd${ABC[P:P+2]}, vldrd${ABC[P+2:P+4]}); 63 const float32x4_t vld${ABC[P:P+4]} = vld_t${ABC[P:P+4]}.val[0]; 64 const float32x4_t vrd${ABC[P:P+4]} = vld_t${ABC[P:P+4]}.val[1]; 65 66 $for P in range(0, PIXEL_TILE, 4): 67 const float32x4x2_t vtl_t${ABC[P:P+4]} = vuzpq_f32(vtltr${ABC[P:P+2]}, vtltr${ABC[P+2:P+4]}); 68 const float32x4_t vtl${ABC[P:P+4]} = vtl_t${ABC[P:P+4]}.val[0]; 69 const float32x4_t vtr${ABC[P:P+4]} = vtl_t${ABC[P:P+4]}.val[1]; 70 71 $for P in range(0, PIXEL_TILE, 4): 72 const float32x4_t vl${ABC[P:P+4]} = ${VMULADDQ_F32}(vtl${ABC[P:P+4]}, vld${ABC[P:P+4]}, valphav${ABC[P:P+4]}); 73 const float32x4_t vr${ABC[P:P+4]} = ${VMULADDQ_F32}(vtr${ABC[P:P+4]}, vrd${ABC[P:P+4]}, valphav${ABC[P:P+4]}); 74 75 $for P in range(0, PIXEL_TILE, 4): 76 const float32x4_t vd${ABC[P:P+4]} = vsubq_f32(vr${ABC[P:P+4]}, vl${ABC[P:P+4]}); 77 78 $for P in range(0, PIXEL_TILE, 4): 79 const float32x4_t vo${ABC[P:P+4]} = ${VMULADDQ_F32}(vl${ABC[P:P+4]}, vd${ABC[P:P+4]}, valphah${ABC[P:P+4]}); 80 81 $for P in range(0, PIXEL_TILE, 4): 82 vst1q_f32(output + ${P}, vo${ABC[P:P+4]}); 83 output += ${PIXEL_TILE}; 84 } 85 86 for (; p >= 4; p -= 4) { 87 $for P in range(4): 88 const float* itl${P} = (const float*) ((uintptr_t) i[${2 * P}] + input_offset); 89 const float* ibl${P} = (const float*) ((uintptr_t) i[${2 * P + 1}] + input_offset); 90 i += 8; 91 92 const float32x4x2_t vw = vld2q_f32(w); 93 w += 8; 94 95 $for P in range(0, 4): 96 const float32x2_t vtltr${ABC[P]} = vld1_f32(itl${P}); 97 const float32x2_t vblbr${ABC[P]} = vld1_f32(ibl${P}); 98 99 const float32x4_t valphah = vw.val[0]; 100 const float32x4_t valphav = vw.val[1]; 101 102 $for P in range(0, 4, 2): 103 const float32x4_t vtltr${ABC[P:P+2]} = vcombine_f32(vtltr${ABC[P]}, vtltr${ABC[P+1]}); 104 const float32x4_t vblbr${ABC[P:P+2]} = vcombine_f32(vblbr${ABC[P]}, vblbr${ABC[P+1]}); 105 106 $for P in range(0, 4, 2): 107 const float32x4_t vldrd${ABC[P:P+2]} = vsubq_f32(vblbr${ABC[P:P+2]}, vtltr${ABC[P:P+2]}); 108 109 const float32x4x2_t vld_t = vuzpq_f32(vldrd01, vldrd23); 110 const float32x4_t vld = vld_t.val[0]; 111 const float32x4_t vrd = vld_t.val[1]; 112 113 const float32x4x2_t vtl_t = vuzpq_f32(vtltr01, vtltr23); 114 const float32x4_t vtl = vtl_t.val[0]; 115 const float32x4_t vtr = vtl_t.val[1]; 116 117 const float32x4_t vl = ${VMULADDQ_F32}(vtl, vld, valphav); 118 const float32x4_t vr = ${VMULADDQ_F32}(vtr, vrd, valphav); 119 120 const float32x4_t vd = vsubq_f32(vr, vl); 121 const float32x4_t vo = ${VMULADDQ_F32}(vl, vd, valphah); 122 123 vst1q_f32(output, vo); 124 output += 4; 125 } 126 127 if XNN_UNLIKELY(p != 0) { 128 if (p & 2) { 129 const float32x2x2_t vw = vld2_f32(w); 130 w += 4; 131 132 const float32x2_t valphah = vw.val[0]; 133 const float32x2_t valphav = vw.val[1]; 134 135 $for P in range(2): 136 const float* itl${P} = (const float*) ((uintptr_t) i[${2 * P}] + input_offset); 137 const float* ibl${P} = (const float*) ((uintptr_t) i[${2 * P + 1}] + input_offset); 138 i += 4; 139 140 $for P in range(0, 2): 141 const float32x2_t vtltr${ABC[P]} = vld1_f32(itl${P}); 142 const float32x2_t vblbr${ABC[P]} = vld1_f32(ibl${P}); 143 144 $for P in range(0, 2): 145 const float32x2_t vldrd${ABC[P]} = vsub_f32(vblbr${ABC[P]}, vtltr${ABC[P]}); 146 147 const float32x2x2_t vld_t = vuzp_f32(vldrd0, vldrd1); 148 const float32x2_t vld = vld_t.val[0]; 149 const float32x2_t vrd = vld_t.val[1]; 150 151 const float32x2x2_t vtl_t = vuzp_f32(vtltr0, vtltr1); 152 const float32x2_t vtl = vtl_t.val[0]; 153 const float32x2_t vtr = vtl_t.val[1]; 154 155 const float32x2_t vl = ${VMULADD_F32}(vtl, vld, valphav); 156 const float32x2_t vr = ${VMULADD_F32}(vtr, vrd, valphav); 157 158 const float32x2_t vd = vsub_f32(vr, vl); 159 const float32x2_t vo = ${VMULADD_F32}(vl, vd, valphah); 160 161 vst1_f32(output, vo); 162 output += 2; 163 } 164 165 if (p & 1) { 166 // We are computing the following formula: 167 // result = (1 - alpha_h) * (1 - alpha_v) * top_left + 168 // alpha_h * (1 - alpha_v) * top_right + 169 // (1 - alpha_h) * alpha_v * bottom_left + 170 // alpha_h * alpha_v * bottom_right. 171 // 172 // Rearranging gives 173 // result = left + alpha_h * (right - left), 174 // where 175 // left = top_left + alpha_v * (bottom_left - top_left), 176 // right = top_right + alpha_v * (bottom_right - top_right). 177 178 const float alphah = *w; 179 const float32x2_t valphav = vld1_dup_f32(w + 1); 180 w += 2; 181 182 const float* itl = (const float*) ((uintptr_t) i[0] + input_offset); 183 const float* ibl = (const float*) ((uintptr_t) i[1] + input_offset); 184 i += 2; 185 186 const float32x2_t vtltr = vld1_f32(itl); 187 const float32x2_t vblbr = vld1_f32(ibl); 188 189 // Compute at once 190 // left_diff = bottom_left - top_left 191 // right_diff = bottom_right - top_right 192 const float32x2_t vldrd = vsub_f32(vblbr, vtltr); 193 const float32x2_t vlr = ${VMULADD_F32}(vtltr, vldrd, valphav); 194 195 // Extract them and compute the result. 196 const float l = vget_lane_f32(vlr, 0); 197 const float r = vget_lane_f32(vlr, 1); 198 199 *output++ = l + alphah * (r - l); 200 } 201 } 202 203 input_offset += input_increment; 204 } while (--channels != 0); 205} 206