// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 8 == 0
$assert BATCH_TILE >= 8
$assert RR_STEPS in [1, 2]
$assert DIV_ALGO in ["div", "nr1fma", "nr2fma"]
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$SIMD_TILE = BATCH_TILE // 8
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/vunary.h>


// Masks for loading/storing 1..7 trailing elements: the first k 32-bit words
// are all-ones (active lanes), the rest are zero.
static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0};

void xnn_f32_sigmoid_ukernel__avx2_rr${RR_STEPS}_p5_${DIV_ALGO}_x${BATCH_TILE}(
    size_t n,
    const float* x,
    float* y,
    const void* params)
{
  assert(n % sizeof(float) == 0);

  const __m256 vsign_mask = _mm256_set1_ps(-0.0f);
  const __m256 vmagic_bias = _mm256_set1_ps(0x1.8000FEp23f);
  const __m256 vlog2e = _mm256_set1_ps(0x1.715476p0f);
  $if RR_STEPS == 1:
    const __m256 vminus_ln2 = _mm256_set1_ps(-0x1.62E43p-1f);
  $else:
    const __m256 vminus_ln2_hi = _mm256_set1_ps(-0x1.62E43p-1f);
    const __m256 vminus_ln2_lo = _mm256_set1_ps(0x1.05C61p-29f);
  const __m256 vc5 = _mm256_set1_ps(0x1.0F9F9Cp-7f);
  const __m256 vc4 = _mm256_set1_ps(0x1.573A1Ap-5f);
  const __m256 vc3 = _mm256_set1_ps(0x1.555A80p-3f);
  const __m256 vc2 = _mm256_set1_ps(0x1.FFFDC6p-2f);
  const __m256 vc1 = _mm256_set1_ps(0x1.FFFFF6p-1f);
  const __m256 vone = _mm256_set1_ps(1.0f);
  const __m256 vdenorm_cutoff = _mm256_set1_ps(-0x1.5D589Ep+6f);

  $if BATCH_TILE > 8:
    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
      const __m256 vx${ABC[0]} = _mm256_loadu_ps(x);
      $for N in range(1, SIMD_TILE):
        const __m256 vx${ABC[N]} = _mm256_loadu_ps(x + ${N * 8});
      x += ${BATCH_TILE};

      // z := -abs(x): evaluate sigmoid on the non-positive half-line, then
      // reconstruct the x >= 0 case at the end as 1 - f.
      $for N in range(SIMD_TILE):
        const __m256 vz${ABC[N]} = _mm256_or_ps(vx${ABC[N]}, vsign_mask);

      // n := round(z / ln(2)), computed by adding the magic bias so that the
      // rounded integer lands in the low bits of vn.
      $for N in range(SIMD_TILE):
        __m256 vn${ABC[N]} = _mm256_fmadd_ps(vz${ABC[N]}, vlog2e, vmagic_bias);

      // s := 2**n, built by shifting the integer bits of vn into the exponent field.
      $for N in range(SIMD_TILE):
        const __m256 vs${ABC[N]} = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn${ABC[N]}), 23));

      // Subtract the magic bias back to recover n as a float.
      $for N in range(SIMD_TILE):
        vn${ABC[N]} = _mm256_sub_ps(vn${ABC[N]}, vmagic_bias);

      // t := z - n * ln(2) (reduced argument); with RR_STEPS == 2, ln(2) is split
      // into high and low parts (Cody-Waite reduction) for extra accuracy.
      $if RR_STEPS == 1:
        $for N in range(SIMD_TILE):
          __m256 vt${ABC[N]} = _mm256_fmadd_ps(vn${ABC[N]}, vminus_ln2, vz${ABC[N]});
      $else:
        $for N in range(SIMD_TILE):
          __m256 vt${ABC[N]} = _mm256_fmadd_ps(vn${ABC[N]}, vminus_ln2_hi, vz${ABC[N]});

        $for N in range(SIMD_TILE):
          vt${ABC[N]} = _mm256_fmadd_ps(vn${ABC[N]}, vminus_ln2_lo, vt${ABC[N]});

      // Degree-5 polynomial approximation: exp(t) ~ 1 + t * p, where
      // p := c1 + t * (c2 + t * (c3 + t * (c4 + t * c5))).
      $for N in range(SIMD_TILE):
        __m256 vp${ABC[N]} = _mm256_fmadd_ps(vc5, vt${ABC[N]}, vc4);

      $for N in range(SIMD_TILE):
        vp${ABC[N]} = _mm256_fmadd_ps(vp${ABC[N]}, vt${ABC[N]}, vc3);

      $for N in range(SIMD_TILE):
        vp${ABC[N]} = _mm256_fmadd_ps(vp${ABC[N]}, vt${ABC[N]}, vc2);

      $for N in range(SIMD_TILE):
        vp${ABC[N]} = _mm256_fmadd_ps(vp${ABC[N]}, vt${ABC[N]}, vc1);

      // e := exp(z), reconstructed as s + (t * s) * p.
      $for N in range(SIMD_TILE):
        vt${ABC[N]} = _mm256_mul_ps(vt${ABC[N]}, vs${ABC[N]});

      $for N in range(SIMD_TILE):
        const __m256 ve${ABC[N]} = _mm256_fmadd_ps(vt${ABC[N]}, vp${ABC[N]}, vs${ABC[N]});

      // d := e + 1, so that sigmoid(z) = e / d.
      $for N in range(SIMD_TILE):
        const __m256 vd${ABC[N]} = _mm256_add_ps(ve${ABC[N]}, vone);

      // f := e / d, either by full-precision division, or as e * rcp(d) with the
      // approximate reciprocal (~12 bits) refined by one (nr1fma) or two (nr2fma)
      // Newton-Raphson steps: r := r + r * (1 - d * r).
      $if DIV_ALGO == "div":
        $for N in range(SIMD_TILE):
          __m256 vf${ABC[N]} = _mm256_div_ps(ve${ABC[N]}, vd${ABC[N]});
      $else:
        $for N in range(SIMD_TILE):
          __m256 vr${ABC[N]} = _mm256_rcp_ps(vd${ABC[N]});

        $for N in range(SIMD_TILE):
          vr${ABC[N]} = _mm256_fmadd_ps(_mm256_fnmadd_ps(vr${ABC[N]}, vd${ABC[N]}, vone), vr${ABC[N]}, vr${ABC[N]});

        $if DIV_ALGO == "nr2fma":
          $for N in range(SIMD_TILE):
            vr${ABC[N]} = _mm256_fmadd_ps(_mm256_fnmadd_ps(vr${ABC[N]}, vd${ABC[N]}, vone), vr${ABC[N]}, vr${ABC[N]});

        $for N in range(SIMD_TILE):
          __m256 vf${ABC[N]} = _mm256_mul_ps(ve${ABC[N]}, vr${ABC[N]});

      // Flush the result to zero where z is below the denormal cutoff: sigmoid(x)
      // underflows there. NaN inputs compare false and pass through unchanged.
      $for N in range(SIMD_TILE):
        vf${ABC[N]} = _mm256_andnot_ps(_mm256_cmp_ps(vz${ABC[N]}, vdenorm_cutoff, _CMP_LT_OS), vf${ABC[N]});

      // For x >= 0 (sign bit clear), reconstruct sigmoid(x) = 1 - f;
      // blendv selects per lane on the sign bit of vx.
      $for N in range(SIMD_TILE):
        vf${ABC[N]} = _mm256_blendv_ps(_mm256_sub_ps(vone, vf${ABC[N]}), vf${ABC[N]}, vx${ABC[N]});

      _mm256_storeu_ps(y, vf${ABC[0]});
      $for N in range(1, SIMD_TILE):
        _mm256_storeu_ps(y + ${N * 8}, vf${ABC[N]});
      y += ${BATCH_TILE};
    }
  // Same computation as above, one 8-element vector at a time.
  for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
    const __m256 vx = _mm256_loadu_ps(x);
    x += 8;

    const __m256 vz = _mm256_or_ps(vx, vsign_mask);

    __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
    const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));
    vn = _mm256_sub_ps(vn, vmagic_bias);

    $if RR_STEPS == 1:
      __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);
    $else:
      __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2_hi, vz);
      vt = _mm256_fmadd_ps(vn, vminus_ln2_lo, vt);

    __m256 vp = _mm256_fmadd_ps(vc5, vt, vc4);
    vp = _mm256_fmadd_ps(vp, vt, vc3);
    vp = _mm256_fmadd_ps(vp, vt, vc2);
    vp = _mm256_fmadd_ps(vp, vt, vc1);

    vt = _mm256_mul_ps(vt, vs);
    const __m256 ve = _mm256_fmadd_ps(vt, vp, vs);

    const __m256 vd = _mm256_add_ps(ve, vone);
    $if DIV_ALGO == "div":
      __m256 vf = _mm256_div_ps(ve, vd);
    $else:
      __m256 vr = _mm256_rcp_ps(vd);
      vr = _mm256_fmadd_ps(_mm256_fnmadd_ps(vr, vd, vone), vr, vr);
      $if DIV_ALGO == "nr2fma":
        vr = _mm256_fmadd_ps(_mm256_fnmadd_ps(vr, vd, vone), vr, vr);
      __m256 vf = _mm256_mul_ps(ve, vr);

    vf = _mm256_andnot_ps(_mm256_cmp_ps(vz, vdenorm_cutoff, _CMP_LT_OS), vf);
    vf = _mm256_blendv_ps(_mm256_sub_ps(vone, vf), vf, vx);

    _mm256_storeu_ps(y, vf);
    y += 8;
  }
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(float));
    assert(n <= 7 * sizeof(float));
    // Load a mask with the first (n / sizeof(float)) lanes active.
    const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &mask_table[7] - n));

    const __m256 vx = _mm256_maskload_ps(x, vmask);

    const __m256 vz = _mm256_or_ps(vx, vsign_mask);

    __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
    const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));
    vn = _mm256_sub_ps(vn, vmagic_bias);

    $if RR_STEPS == 1:
      __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);
    $else:
      __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2_hi, vz);
      vt = _mm256_fmadd_ps(vn, vminus_ln2_lo, vt);

    __m256 vp = _mm256_fmadd_ps(vc5, vt, vc4);
    vp = _mm256_fmadd_ps(vp, vt, vc3);
    vp = _mm256_fmadd_ps(vp, vt, vc2);
    vp = _mm256_fmadd_ps(vp, vt, vc1);

    vt = _mm256_mul_ps(vt, vs);
    const __m256 ve = _mm256_fmadd_ps(vt, vp, vs);

    const __m256 vd = _mm256_add_ps(ve, vone);
    $if DIV_ALGO == "div":
      __m256 vf = _mm256_div_ps(ve, vd);
    $else:
      __m256 vr = _mm256_rcp_ps(vd);
      vr = _mm256_fmadd_ps(_mm256_fnmadd_ps(vr, vd, vone), vr, vr);
      $if DIV_ALGO == "nr2fma":
        vr = _mm256_fmadd_ps(_mm256_fnmadd_ps(vr, vd, vone), vr, vr);
      __m256 vf = _mm256_mul_ps(ve, vr);

    vf = _mm256_andnot_ps(_mm256_cmp_ps(vz, vdenorm_cutoff, _CMP_LT_OS), vf);
    vf = _mm256_blendv_ps(_mm256_sub_ps(vone, vf), vf, vx);

    // _mm256_maskstore_ps(y, vmask, vf) could be used here, but triggers msan failures (probably an msan bug).
    __m128 vf_lo = _mm256_castps256_ps128(vf);
    if (n & (4 * sizeof(float))) {
      _mm_storeu_ps(y, vf_lo);
      vf_lo = _mm256_extractf128_ps(vf, 1);
      y += 4;
    }
    if (n & (2 * sizeof(float))) {
      _mm_storel_pi((__m64*) y, vf_lo);
      vf_lo = _mm_movehl_ps(vf_lo, vf_lo);
      y += 2;
    }
    if (n & (1 * sizeof(float))) {
      _mm_store_ss(y, vf_lo);
    }
  }
}
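
// Example of a hypothetical instantiation: with BATCH_TILE=16, RR_STEPS=1, and
// DIV_ALGO="div", the generator emits xnn_f32_sigmoid_ukernel__avx2_rr1_p5_div_x16.
// The ukernel takes the element count in bytes, and this variant ignores params,
// so a sketch of a direct call (illustrative only) might look like:
//
//   float input[16] = { /* ... */ };
//   float output[16];
//   xnn_f32_sigmoid_ukernel__avx2_rr1_p5_div_x16(16 * sizeof(float), input, output, NULL);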