// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$SSE_HEADER = {2: "emmintrin.h", 3: "tmmintrin.h", 4: "smmintrin.h"}[SSE]
$assert CHANNEL_TILE % 8 == 0
$assert CHANNEL_TILE >= 8
$assert ROW_TILE >= 2
$assert ACCUMULATORS >= 1
$assert ROW_TILE >= ACCUMULATORS * 2
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <${SSE_HEADER}>

#include <xnnpack/gavgpool.h>


$ISA = {2: "sse2", 3: "ssse3", 4: "sse41"}[SSE]
void xnn_qs8_gavgpool_minmax_ukernel_${ROW_TILE}x__${ISA}_c${CHANNEL_TILE}${"" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS}(
    size_t rows,
    size_t channels,
    const int8_t* input,
    size_t input_stride,
    const int8_t* zero,
    int8_t* output,
    const union xnn_qs8_avgpool_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
{
  assert(rows != 0);
  assert(rows <= ${ROW_TILE});
  assert(channels != 0);

  const int8_t* i0 = input;
  $for M in range(1, ROW_TILE):
    const int8_t* i${M} = (const int8_t*) ((uintptr_t) i${M-1} + input_stride);
    $if M % 2 == 1:
      if XNN_UNPREDICTABLE(rows < ${M+1}) {
        i${M} = zero;
      }
    $else:
      if XNN_UNPREDICTABLE(rows <= ${M}) {
        i${M} = zero;
      }

  const __m128i vbias = _mm_load_si128((const __m128i*) params->sse2.bias);
  const __m128i vmultiplier = _mm_load_si128((const __m128i*) params->sse2.multiplier);
  const __m128i vrounding = _mm_load_si128((const __m128i*) params->sse2.rounding);
  const __m128i vshift = _mm_loadl_epi64((const __m128i*) params->sse2.shift);
  while (channels >= ${CHANNEL_TILE}) {
    $for M in range(ROW_TILE):
      $if SSE >= 4:
        const __m128i vxi${M}x${ABC[0:8]} = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i${M}));
        $for C in range(8, CHANNEL_TILE, 8):
          const __m128i vxi${M}x${ABC[C:C+8]} = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) (i${M} + ${C})));
      $else:
        const __m128i vi${M}x${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i${M});
        $for C in range(8, CHANNEL_TILE, 8):
          const __m128i vi${M}x${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i${M} + ${C}));
      i${M} += ${CHANNEL_TILE};

    $if SSE < 4:
      $for M in range(ROW_TILE):
        $for C in range(0, CHANNEL_TILE, 8):
          const __m128i vxi${M}x${ABC[C:C+8]} = _mm_unpacklo_epi8(vi${M}x${ABC[C:C+8]}, _mm_cmpgt_epi8(_mm_setzero_si128(), vi${M}x${ABC[C:C+8]}));

    $for A in range(ACCUMULATORS):
      $for C in range(0, CHANNEL_TILE, 8):
        __m128i vacc${A}x${ABC[C:C+8]} = _mm_add_epi16(vxi${A*2}x${ABC[C:C+8]}, vxi${A*2+1}x${ABC[C:C+8]});

    $for M in range(ACCUMULATORS * 2, ROW_TILE):
      $for C in range(0, CHANNEL_TILE, 8):
        vacc${M % ACCUMULATORS}x${ABC[C:C+8]} = _mm_add_epi16(vacc${M % ACCUMULATORS}x${ABC[C:C+8]}, vxi${M}x${ABC[C:C+8]});

    $if ACCUMULATORS > 1:
      // Add up all accumulators to vacc0x${ABC[0:CHANNEL_TILE]}
      $ACC_SLICE = 1
      $while ACC_SLICE < ACCUMULATORS:
        $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
          $if A + ACC_SLICE < ACCUMULATORS:
            $for C in range(0, CHANNEL_TILE, 8):
              vacc${A}x${ABC[C:C+8]} = _mm_add_epi16(vacc${A}x${ABC[C:C+8]}, vacc${A + ACC_SLICE}x${ABC[C:C+8]});
        $ACC_SLICE *= 2

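    // Sign-extend the 16-bit row sums to 32 bits and add the bias.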
    $for C in range(0, CHANNEL_TILE, 8):
      $if SSE >= 4:
        const __m128i vacc${ABC[C:C+4]} = _mm_add_epi32(vbias, _mm_cvtepi16_epi32(vacc0x${ABC[C:C+8]}));
        const __m128i vacc${ABC[C+4:C+8]} = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vacc0x${ABC[C:C+8]}, _mm_cmpgt_epi16(_mm_setzero_si128(), vacc0x${ABC[C:C+8]})));
      $else:
        const __m128i vsgnacc0x${ABC[C:C+8]} = _mm_cmpgt_epi16(_mm_setzero_si128(), vacc0x${ABC[C:C+8]});
        const __m128i vacc${ABC[C:C+4]} = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vacc0x${ABC[C:C+8]}, vsgnacc0x${ABC[C:C+8]}));
        const __m128i vacc${ABC[C+4:C+8]} = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vacc0x${ABC[C:C+8]}, vsgnacc0x${ABC[C:C+8]}));

    $if SSE >= 3:
      $for C in range(0, CHANNEL_TILE, 4):
        const __m128i vabsacc${ABC[C:C+4]} = _mm_abs_epi32(vacc${ABC[C:C+4]});
    $else:
      $for C in range(0, CHANNEL_TILE, 4):
        const __m128i vsgnacc${ABC[C:C+4]} = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc${ABC[C:C+4]});

      $for C in range(0, CHANNEL_TILE, 4):
        const __m128i vabsacc${ABC[C:C+4]} = _mm_sub_epi32(_mm_xor_si128(vacc${ABC[C:C+4]}, vsgnacc${ABC[C:C+4]}), vsgnacc${ABC[C:C+4]});

    $for C in range(0, CHANNEL_TILE, 4):
      const __m128i vabsacc${ABC[C+1:C+4:2]} = _mm_shuffle_epi32(vabsacc${ABC[C:C+4]}, _MM_SHUFFLE(3, 3, 1, 1));

    $for C in range(0, CHANNEL_TILE, 4):
      const __m128i vabsprod${ABC[C:C+4:2]} = _mm_mul_epu32(vabsacc${ABC[C:C+4]}, vmultiplier);
      const __m128i vabsprod${ABC[C+1:C+4:2]} = _mm_mul_epu32(vabsacc${ABC[C+1:C+4:2]}, vmultiplier);

    $for C in range(0, CHANNEL_TILE, 4):
      const __m128i vabsout${ABC[C:C+4:2]} = _mm_srl_epi64(_mm_add_epi64(vabsprod${ABC[C:C+4:2]}, vrounding), vshift);
      const __m128i vabsout${ABC[C+1:C+4:2]} = _mm_srl_epi64(_mm_add_epi64(vabsprod${ABC[C+1:C+4:2]}, vrounding), vshift);

    $if SSE >= 4:
      $for C in range(0, CHANNEL_TILE, 4):
        const __m128i vabsout${ABC[C:C+4]} = _mm_blend_epi16(vabsout${ABC[C:C+4:2]}, _mm_shuffle_epi32(vabsout${ABC[C+1:C+4:2]}, _MM_SHUFFLE(2, 2, 0, 0)), 0xCC);
    $else:
      $for C in range(0, CHANNEL_TILE, 4):
        const __m128i vabsout${ABC[C:C+4:2]}${ABC[C+1:C+4:2]} = _mm_castps_si128(
          _mm_shuffle_ps(_mm_castsi128_ps(vabsout${ABC[C:C+4:2]}), _mm_castsi128_ps(vabsout${ABC[C+1:C+4:2]}), _MM_SHUFFLE(2, 0, 2, 0)));

      $for C in range(0, CHANNEL_TILE, 4):
        const __m128i vabsout${ABC[C:C+4]} = _mm_shuffle_epi32(vabsout${ABC[C:C+4:2]}${ABC[C+1:C+4:2]}, _MM_SHUFFLE(3, 1, 2, 0));

    $if SSE >= 3:
      $for C in range(0, CHANNEL_TILE, 4):
        const __m128i vout${ABC[C:C+4]} = _mm_sign_epi32(vabsout${ABC[C:C+4]}, vacc${ABC[C:C+4]});
    $else:
      $for C in range(0, CHANNEL_TILE, 4):
        const __m128i vout${ABC[C:C+4]} = _mm_sub_epi32(_mm_xor_si128(vabsout${ABC[C:C+4]}, vsgnacc${ABC[C:C+4]}), vsgnacc${ABC[C:C+4]});

    const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->sse2.output_zero_point);
    $for C in range(0, CHANNEL_TILE, 8):
      __m128i vout${ABC[C:C+8]} = _mm_adds_epi16(_mm_packs_epi32(vout${ABC[C:C+4]}, vout${ABC[C+4:C+8]}), voutput_zero_point);

    const __m128i voutput_min = _mm_load_si128((const __m128i*) params->sse2.output_min);
    const __m128i voutput_max = _mm_load_si128((const __m128i*) params->sse2.output_max);
    $for C in range(0, CHANNEL_TILE, 8):
      vout${ABC[C:C+8]} = _mm_min_epi16(_mm_max_epi16(vout${ABC[C:C+8]}, voutput_min), voutput_max);

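    // Pack the 16-bit results to 8 bits with signed saturation and store.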
    $for C in range(0, CHANNEL_TILE, 16):
      $if C + 8 < CHANNEL_TILE:
        __m128i vout${ABC[C:C+16]} = _mm_packs_epi16(vout${ABC[C:C+8]}, vout${ABC[C+8:C+16]});
      $else:
        __m128i vout${ABC[C:C+8]}${ABC[C:C+8]} = _mm_packs_epi16(vout${ABC[C:C+8]}, vout${ABC[C:C+8]});

    $if CHANNEL_TILE > 8:
      _mm_storeu_si128((__m128i*) output, vout${ABC[0:16]});
    $else:
      _mm_storel_epi64((__m128i*) output, vout${ABC[0:8]}${ABC[0:8]});
    $for C in range(16, CHANNEL_TILE, 16):
      $if C + 8 < CHANNEL_TILE:
        _mm_storeu_si128((__m128i*) (output + ${C}), vout${ABC[C:C+16]});
      $else:
        _mm_storel_epi64((__m128i*) (output + ${C}), vout${ABC[C:C+8]}${ABC[C:C+8]});
    output += ${CHANNEL_TILE};

    channels -= ${CHANNEL_TILE};
  }
  if XNN_UNLIKELY(channels != 0) {
    ${"do " if CHANNEL_TILE > 8 else ""}{
      $for M in range(ROW_TILE):
        $if SSE >= 4:
          const __m128i vxi${M}x${ABC[0:8]} = _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) i${M}));
        $else:
          const __m128i vi${M}x${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i${M});
        i${M} += 8;

      $if SSE < 4:
        $for M in range(ROW_TILE):
          const __m128i vxi${M}x${ABC[0:8]} = _mm_unpacklo_epi8(vi${M}x${ABC[0:8]}, _mm_cmpgt_epi8(_mm_setzero_si128(), vi${M}x${ABC[0:8]}));

      $for A in range(ACCUMULATORS):
        __m128i vacc${A}x${ABC[0:8]} = _mm_add_epi16(vxi${A*2}x${ABC[0:8]}, vxi${A*2+1}x${ABC[0:8]});

      $for M in range(ACCUMULATORS * 2, ROW_TILE):
        vacc${M % ACCUMULATORS}x${ABC[0:8]} = _mm_add_epi16(vacc${M % ACCUMULATORS}x${ABC[0:8]}, vxi${M}x${ABC[0:8]});

      $if ACCUMULATORS > 1:
        // Add up all accumulators to vacc0x${ABC[0:8]}
        $ACC_SLICE = 1
        $while ACC_SLICE < ACCUMULATORS:
          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
            $if A + ACC_SLICE < ACCUMULATORS:
              vacc${A}x${ABC[0:8]} = _mm_add_epi16(vacc${A}x${ABC[0:8]}, vacc${A + ACC_SLICE}x${ABC[0:8]});
          $ACC_SLICE *= 2

      $if SSE >= 4:
        const __m128i vacc${ABC[0:4]} = _mm_add_epi32(vbias, _mm_cvtepi16_epi32(vacc0x${ABC[0:8]}));
        const __m128i vacc${ABC[4:8]} = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vacc0x${ABC[0:8]}, _mm_cmpgt_epi16(_mm_setzero_si128(), vacc0x${ABC[0:8]})));
      $else:
        const __m128i vsgnacc0x${ABC[0:8]} = _mm_cmpgt_epi16(_mm_setzero_si128(), vacc0x${ABC[0:8]});
        const __m128i vacc${ABC[0:4]} = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vacc0x${ABC[0:8]}, vsgnacc0x${ABC[0:8]}));
        const __m128i vacc${ABC[4:8]} = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vacc0x${ABC[0:8]}, vsgnacc0x${ABC[0:8]}));

      $if SSE >= 3:
        const __m128i vabsacc${ABC[0:4]} = _mm_abs_epi32(vacc${ABC[0:4]});
        const __m128i vabsacc${ABC[4:8]} = _mm_abs_epi32(vacc${ABC[4:8]});
      $else:
        const __m128i vsgnacc${ABC[0:4]} = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc${ABC[0:4]});
        const __m128i vsgnacc${ABC[4:8]} = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc${ABC[4:8]});

        const __m128i vabsacc${ABC[0:4]} = _mm_sub_epi32(_mm_xor_si128(vacc${ABC[0:4]}, vsgnacc${ABC[0:4]}), vsgnacc${ABC[0:4]});
        const __m128i vabsacc${ABC[4:8]} = _mm_sub_epi32(_mm_xor_si128(vacc${ABC[4:8]}, vsgnacc${ABC[4:8]}), vsgnacc${ABC[4:8]});

      const __m128i vabsacc${ABC[1:4:2]} = _mm_shuffle_epi32(vabsacc${ABC[0:4]}, _MM_SHUFFLE(3, 3, 1, 1));
      const __m128i vabsacc${ABC[5:8:2]} = _mm_shuffle_epi32(vabsacc${ABC[4:8]}, _MM_SHUFFLE(3, 3, 1, 1));

      const __m128i vabsprod${ABC[0:4:2]} = _mm_mul_epu32(vabsacc${ABC[0:4]}, vmultiplier);
      const __m128i vabsprod${ABC[1:4:2]} = _mm_mul_epu32(vabsacc${ABC[1:4:2]}, vmultiplier);
      const __m128i vabsprod${ABC[4:8:2]} = _mm_mul_epu32(vabsacc${ABC[4:8]}, vmultiplier);
      const __m128i vabsprod${ABC[5:8:2]} = _mm_mul_epu32(vabsacc${ABC[5:8:2]}, vmultiplier);

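      // Add the rounding constant and shift right to complete the fixed-point scaling.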
      const __m128i vabsout${ABC[0:4:2]} = _mm_srl_epi64(_mm_add_epi64(vabsprod${ABC[0:4:2]}, vrounding), vshift);
      const __m128i vabsout${ABC[1:4:2]} = _mm_srl_epi64(_mm_add_epi64(vabsprod${ABC[1:4:2]}, vrounding), vshift);
      const __m128i vabsout${ABC[4:8:2]} = _mm_srl_epi64(_mm_add_epi64(vabsprod${ABC[4:8:2]}, vrounding), vshift);
      const __m128i vabsout${ABC[5:8:2]} = _mm_srl_epi64(_mm_add_epi64(vabsprod${ABC[5:8:2]}, vrounding), vshift);

      $if SSE >= 4:
        const __m128i vabsout${ABC[0:4]} = _mm_blend_epi16(vabsout${ABC[0:4:2]}, _mm_shuffle_epi32(vabsout${ABC[1:4:2]}, _MM_SHUFFLE(2, 2, 0, 0)), 0xCC);
        const __m128i vabsout${ABC[4:8]} = _mm_blend_epi16(vabsout${ABC[4:8:2]}, _mm_shuffle_epi32(vabsout${ABC[5:8:2]}, _MM_SHUFFLE(2, 2, 0, 0)), 0xCC);
      $else:
        const __m128i vabsout${ABC[0:4:2]}${ABC[1:4:2]} = _mm_castps_si128(
          _mm_shuffle_ps(_mm_castsi128_ps(vabsout${ABC[0:4:2]}), _mm_castsi128_ps(vabsout${ABC[1:4:2]}), _MM_SHUFFLE(2, 0, 2, 0)));
        const __m128i vabsout${ABC[4:8:2]}${ABC[5:8:2]} = _mm_castps_si128(
          _mm_shuffle_ps(_mm_castsi128_ps(vabsout${ABC[4:8:2]}), _mm_castsi128_ps(vabsout${ABC[5:8:2]}), _MM_SHUFFLE(2, 0, 2, 0)));

        const __m128i vabsout${ABC[0:4]} = _mm_shuffle_epi32(vabsout${ABC[0:4:2]}${ABC[1:4:2]}, _MM_SHUFFLE(3, 1, 2, 0));
        const __m128i vabsout${ABC[4:8]} = _mm_shuffle_epi32(vabsout${ABC[4:8:2]}${ABC[5:8:2]}, _MM_SHUFFLE(3, 1, 2, 0));

      $if SSE >= 3:
        const __m128i vout${ABC[0:4]} = _mm_sign_epi32(vabsout${ABC[0:4]}, vacc${ABC[0:4]});
        const __m128i vout${ABC[4:8]} = _mm_sign_epi32(vabsout${ABC[4:8]}, vacc${ABC[4:8]});
      $else:
        const __m128i vout${ABC[0:4]} = _mm_sub_epi32(_mm_xor_si128(vabsout${ABC[0:4]}, vsgnacc${ABC[0:4]}), vsgnacc${ABC[0:4]});
        const __m128i vout${ABC[4:8]} = _mm_sub_epi32(_mm_xor_si128(vabsout${ABC[4:8]}, vsgnacc${ABC[4:8]}), vsgnacc${ABC[4:8]});

      const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->sse2.output_zero_point);
      __m128i vout${ABC[0:8]} = _mm_adds_epi16(_mm_packs_epi32(vout${ABC[0:4]}, vout${ABC[4:8]}), voutput_zero_point);

      const __m128i voutput_min = _mm_load_si128((const __m128i*) params->sse2.output_min);
      const __m128i voutput_max = _mm_load_si128((const __m128i*) params->sse2.output_max);
      vout${ABC[0:8]} = _mm_min_epi16(_mm_max_epi16(vout${ABC[0:8]}, voutput_min), voutput_max);

      __m128i vout${ABC[0:8]}${ABC[0:8]} = _mm_packs_epi16(vout${ABC[0:8]}, vout${ABC[0:8]});

      $if CHANNEL_TILE > 8:
        if XNN_LIKELY(channels >= 8) {
          _mm_storel_epi64((__m128i*) output, vout${ABC[0:8]}${ABC[0:8]});
          output += 8;
          channels -= 8;
        } else {
          if (channels & 4) {
            *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]});
            vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi64(vout${ABC[0:8]}${ABC[0:8]}, 32);
            output += 4;
          }
          if (channels & 2) {
            *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout${ABC[0:8]}${ABC[0:8]}, 0);
            vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi32(vout${ABC[0:8]}${ABC[0:8]}, 16);
            output += 2;
          }
          if (channels & 1) {
            $if SSE >= 4:
              *output = (int8_t) _mm_extract_epi8(vout${ABC[0:8]}${ABC[0:8]}, 0);
            $else:
              *output = (int32_t) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]});
            output += 1;
          }
          channels = 0;
        }
      $else:
        if (channels & 4) {
          *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]});
          vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi64(vout${ABC[0:8]}${ABC[0:8]}, 32);
          output += 4;
        }
        if (channels & 2) {
          *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout${ABC[0:8]}${ABC[0:8]}, 0);
          vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi32(vout${ABC[0:8]}${ABC[0:8]}, 16);
          output += 2;
        }
        if (channels & 1) {
          $if SSE >= 4:
            *output = (int8_t) _mm_extract_epi8(vout${ABC[0:8]}${ABC[0:8]}, 0);
          $else:
            *output = (int32_t) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]});
        }
    }${" while (channels != 0);" if CHANNEL_TILE > 8 else ""}
  }
}