// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 8 == 0
$assert BATCH_TILE >= 8
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>
#include <stdint.h>

#include <immintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/hswish.h>


// Loading 8 int32 values at &mask_table[7 - k] yields k all-ones lanes followed by
// zeroes, selecting the k = 1..7 remainder elements in _mm256_maskload_ps below.
static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0};

$ISA = {0: "avx", 3: "fma3"}[FMA]
void xnn_f32_hswish_ukernel__${ISA}_x${BATCH_TILE}(
    size_t n,
    const float* x,
    float* y,
    const union xnn_f32_hswish_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);

  const __m256 vsixth = _mm256_broadcast_ps((const __m128*) params->sse.sixth);
  const __m256 vhalf = _mm256_broadcast_ps((const __m128*) params->sse.half);
  const __m256 vone = _mm256_broadcast_ps((const __m128*) params->sse.one);
  const __m256 vzero = _mm256_setzero_ps();

  // hswish(x) = x * min(max(x / 6 + 1/2, 0), 1), computed ${BATCH_TILE} elements per iteration.
  for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
    const __m256 vx${ABC[0:8]} = _mm256_loadu_ps(x);
    $for N in range(8, BATCH_TILE, 8):
      const __m256 vx${ABC[N:N+8]} = _mm256_loadu_ps(x + ${N});
    x += ${BATCH_TILE};

    $if FMA == 3:
      $for N in range(0, BATCH_TILE, 8):
        __m256 vacc${ABC[N:N+8]} = _mm256_fmadd_ps(vx${ABC[N:N+8]}, vsixth, vhalf);
    $else:
      $for N in range(0, BATCH_TILE, 8):
        __m256 vacc${ABC[N:N+8]} = _mm256_mul_ps(vx${ABC[N:N+8]}, vsixth);

      $for N in range(0, BATCH_TILE, 8):
        vacc${ABC[N:N+8]} = _mm256_add_ps(vacc${ABC[N:N+8]}, vhalf);

    $for N in range(0, BATCH_TILE, 8):
      vacc${ABC[N:N+8]} = _mm256_max_ps(vacc${ABC[N:N+8]}, vzero);

    $for N in range(0, BATCH_TILE, 8):
      vacc${ABC[N:N+8]} = _mm256_min_ps(vacc${ABC[N:N+8]}, vone);

    $for N in range(0, BATCH_TILE, 8):
      vacc${ABC[N:N+8]} = _mm256_mul_ps(vacc${ABC[N:N+8]}, vx${ABC[N:N+8]});

    _mm256_storeu_ps(y, vacc${ABC[0:8]});
    $for N in range(8, BATCH_TILE, 8):
      _mm256_storeu_ps(y + ${N}, vacc${ABC[N:N+8]});
    y += ${BATCH_TILE};
  }
  $if BATCH_TILE > 8:
    for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
      const __m256 vx = _mm256_loadu_ps(x);
      x += 8;
      $if FMA == 3:
        __m256 vacc = _mm256_fmadd_ps(vx, vsixth, vhalf);
      $else:
        __m256 vacc = _mm256_mul_ps(vx, vsixth);
        vacc = _mm256_add_ps(vacc, vhalf);
      vacc = _mm256_max_ps(vacc, vzero);
      vacc = _mm256_min_ps(vacc, vone);
      vacc = _mm256_mul_ps(vacc, vx);
      _mm256_storeu_ps(y, vacc);
      y += 8;
    }
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(float));
    assert(n <= 7 * sizeof(float));
    const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &mask_table[7] - n));

    const __m256 vx = _mm256_maskload_ps(x, vmask);
    $if FMA == 3:
      __m256 vacc = _mm256_fmadd_ps(vx, vsixth, vhalf);
    $else:
      __m256 vacc = _mm256_mul_ps(vx, vsixth);
      vacc = _mm256_add_ps(vacc, vhalf);
    vacc = _mm256_max_ps(vacc, vzero);
    vacc = _mm256_min_ps(vacc, vone);
    vacc = _mm256_mul_ps(vacc, vx);

    // _mm256_maskstore_ps(y, vmask, vacc) could be used here, but triggers msan failures (probably an msan bug).
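    // Instead, store the 1-7 remainder elements with progressively narrower stores
    // (4, 2, then 1 floats), selected by the bits of the remaining byte count n.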
    __m128 vacc_lo = _mm256_castps256_ps128(vacc);
    if (n & (4 * sizeof(float))) {
      _mm_storeu_ps(y, vacc_lo);
      vacc_lo = _mm256_extractf128_ps(vacc, 1);
      y += 4;
    }
    if (n & (2 * sizeof(float))) {
      _mm_storel_pi((__m64*) y, vacc_lo);
      vacc_lo = _mm_movehl_ps(vacc_lo, vacc_lo);
      y += 2;
    }
    if (n & (1 * sizeof(float))) {
      _mm_store_ss(y, vacc_lo);
    }
  }
}