1// Copyright 2020 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert ELEMENTS_TILE % 4 == 0 7$assert ELEMENTS_TILE >= 4 8$SIMD_TILE = ELEMENTS_TILE // 4 9$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 10#include <assert.h> 11 12#include <wasm_simd128.h> 13 14#include <xnnpack/common.h> 15#include <xnnpack/raddstoreexpminusmax.h> 16 17 18void xnn_f32_raddstoreexpminusmax_ukernel__wasmsimd_p5_x${ELEMENTS_TILE}${"" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS}( 19 size_t elements, 20 const float* input, 21 float* output, 22 float* sum, 23 float max) XNN_DISABLE_TSAN 24{ 25 assert(elements % sizeof(float) == 0); 26 27 const v128_t vmagic_bias = wasm_f32x4_splat(0x1.8000FEp23f); 28 // The smallest x for which expf(x) is normalized. 29 const v128_t vdenorm_cutoff = wasm_f32x4_splat(-0x1.5D589Ep6f); 30 const v128_t vlog2e = wasm_f32x4_splat(0x1.715476p+0f); 31 // Last 7 bits are zeroes 32 const v128_t vminus_ln2_hi = wasm_f32x4_splat(-0x1.62E400p-1f); 33 const v128_t vminus_ln2_lo = wasm_f32x4_splat(-0x1.7F7D1Cp-20f); 34 35 const v128_t vc1 = wasm_f32x4_splat(0x1.FFFFF6p-1f); 36 const v128_t vc2 = wasm_f32x4_splat(0x1.FFFDC6p-2f); 37 const v128_t vc3 = wasm_f32x4_splat(0x1.555A80p-3f); 38 const v128_t vc4 = wasm_f32x4_splat(0x1.573A1Ap-5f); 39 const v128_t vc5 = wasm_f32x4_splat(0x1.0F9F9Cp-7f); 40 41 const v128_t vi_max = wasm_f32x4_splat(max); 42 43 v128_t vacc0 = wasm_f32x4_splat(0.0f); 44 $for K in range(1, ACCUMULATORS): 45 v128_t vacc${K} = vacc0; 46 for (; elements >= ${ELEMENTS_TILE} * sizeof(float); elements -= ${ELEMENTS_TILE} * sizeof(float)) { 47 // Load ${ELEMENTS_TILE} (${SIMD_TILE}x4) inputs at a time. 48 const v128_t vi${ABC[0:4]} = wasm_v128_load(input); 49 $for N in range(4, ELEMENTS_TILE, 4): 50 const v128_t vi${ABC[N:N+4]} = wasm_v128_load(input + ${N}); 51 input += ${ELEMENTS_TILE}; 52 53 // Subtract maximum input x := i - i_max. This implies x <= 0. 54 $for N in range(0, ELEMENTS_TILE, 4): 55 const v128_t vx${ABC[N:N+4]} = wasm_f32x4_sub(vi${ABC[N:N+4]}, vi_max); 56 57 // Compute reduced argument elements := round(x / log(2)). 58 $for N in range(0, ELEMENTS_TILE, 4): 59 v128_t vn${ABC[N:N+4]} = wasm_f32x4_add(vmagic_bias, wasm_f32x4_mul(vx${ABC[N:N+4]}, vlog2e)); 60 61 // Create a floating-point number s (scale) such that s == 2**elements for inputs which don't cause underflow, i.e. 62 // -87.33642 <= x <= 0.0, and -126 <= elements <= 0 accordingly. 63 $for N in range(0, ELEMENTS_TILE, 4): 64 const v128_t vs${ABC[N:N+4]} = wasm_i32x4_shl(vn${ABC[N:N+4]}, 23); 65 66 // Subtract the large number back to get final elements := round(x / log(2)). 67 $for N in range(0, ELEMENTS_TILE, 4): 68 vn${ABC[N:N+4]} = wasm_f32x4_sub(vn${ABC[N:N+4]}, vmagic_bias); 69 70 // Compute reduced argument t := x - elements * log(2). 71 // Use Cody-Waite range reduction method (note two constants to represent log(2)) to improve accuracy. 72 $for N in range(0, ELEMENTS_TILE, 4): 73 v128_t vt${ABC[N:N+4]} = wasm_f32x4_add(vx${ABC[N:N+4]}, wasm_f32x4_mul(vn${ABC[N:N+4]}, vminus_ln2_hi)); 74 75 $for N in range(0, ELEMENTS_TILE, 4): 76 vt${ABC[N:N+4]} = wasm_f32x4_add(vt${ABC[N:N+4]}, wasm_f32x4_mul(vn${ABC[N:N+4]}, vminus_ln2_lo)); 77 78 // Compute degree-5 polynomial approximation for exp(t) on [-log(2)/2, log(2)/2]. 79 $for N in range(0, ELEMENTS_TILE, 4): 80 v128_t vp${ABC[N:N+4]} = wasm_f32x4_add(vc4, wasm_f32x4_mul(vc5, vt${ABC[N:N+4]})); 81 82 $for N in range(0, ELEMENTS_TILE, 4): 83 vp${ABC[N:N+4]} = wasm_f32x4_add(vc3, wasm_f32x4_mul(vp${ABC[N:N+4]}, vt${ABC[N:N+4]})); 84 85 $for N in range(0, ELEMENTS_TILE, 4): 86 vp${ABC[N:N+4]} = wasm_f32x4_add(vc2, wasm_f32x4_mul(vp${ABC[N:N+4]}, vt${ABC[N:N+4]})); 87 88 $for N in range(0, ELEMENTS_TILE, 4): 89 vp${ABC[N:N+4]} = wasm_f32x4_add(vc1, wasm_f32x4_mul(vp${ABC[N:N+4]}, vt${ABC[N:N+4]})); 90 91 // Reconstruct the final f value: 92 // f = s * (1 + t * (c1 + t * (c2 + t * (c3 + t * (c4 + t * c5))))) 93 // = s + (t * s) * (c1 + t * (c2 + t * (c3 + t * (c4 + t * c5)))) 94 // = s + (t * s) * p 95 $for N in range(0, ELEMENTS_TILE, 4): 96 vt${ABC[N:N+4]} = wasm_f32x4_mul(vt${ABC[N:N+4]}, vs${ABC[N:N+4]}); 97 98 $for N in range(0, ELEMENTS_TILE, 4): 99 v128_t vf${ABC[N:N+4]} = wasm_f32x4_add(vs${ABC[N:N+4]}, wasm_f32x4_mul(vt${ABC[N:N+4]}, vp${ABC[N:N+4]})); 100 101 // For inputs below zero cutoff, replace output with +0.0f. 102 // Note that for NaN inputs, comparison result is false, and outputs are left unchanged. 103 $for N in range(0, ELEMENTS_TILE, 4): 104 vf${ABC[N:N+4]} = wasm_v128_andnot(vf${ABC[N:N+4]}, wasm_f32x4_lt(vx${ABC[N:N+4]}, vdenorm_cutoff)); 105 106 // Store ${ELEMENTS_TILE} (${SIMD_TILE}x4) outputs at a time. 107 wasm_v128_store(output, vf${ABC[0:4]}); 108 $for N in range(4, ELEMENTS_TILE, 4): 109 wasm_v128_store(output + ${N}, vf${ABC[N:N+4]}); 110 output += ${ELEMENTS_TILE}; 111 112 // Accumulate computed exponents. 113 $for N in range(0, ELEMENTS_TILE, 4): 114 vacc${N % ACCUMULATORS} = wasm_f32x4_add(vacc${N % ACCUMULATORS}, vf${ABC[N:N+4]}); 115 } 116 $if ACCUMULATORS > 1: 117 // Add up all accumulators to vacc0 118 $ACC_SLICE = 1 119 $while ACC_SLICE < ACCUMULATORS: 120 $for A in range(0, ACCUMULATORS, ACC_SLICE * 2): 121 $if A + ACC_SLICE < ACCUMULATORS: 122 vacc${A} = wasm_f32x4_add(vacc${A}, vacc${A + ACC_SLICE}); 123 $ACC_SLICE *= 2 124 125 v128_t vacc = vacc0; 126 for (; elements >= 4 * sizeof(float); elements -= 4 * sizeof(float)) { 127 // Load 4 inputs at a time. 128 const v128_t vi = wasm_v128_load(input); 129 input += 4; 130 131 // Subtract maximum input x := i - i_max. This implies x <= 0. 132 const v128_t vx = wasm_f32x4_sub(vi, vi_max); 133 134 // Compute reduced argument elements := round(x / log(2)). 135 v128_t vn = wasm_f32x4_add(vmagic_bias, wasm_f32x4_mul(vx, vlog2e)); 136 137 // Create a floating-point number s (scale) such that s == 2**elements for inputs which don't cause underflow, i.e. 138 // -87.33642 <= x <= 0.0, and -126 <= elements <= 0 accordingly. 139 const v128_t vs = wasm_i32x4_shl(vn, 23); 140 141 // Subtract the large number back to get final elements := round(x / log(2)). 142 vn = wasm_f32x4_sub(vn, vmagic_bias); 143 144 // Compute reduced argument t := x - elements * log(2). 145 // Use Cody-Waite range reduction method (note two constants to represent log(2)) to improve accuracy. 146 v128_t vt = wasm_f32x4_add(vx, wasm_f32x4_mul(vn, vminus_ln2_hi)); 147 vt = wasm_f32x4_add(vt, wasm_f32x4_mul(vn, vminus_ln2_lo)); 148 149 // Compute degree-5 polynomial approximation for exp(t) on [-log(2)/2, log(2)/2]. 150 v128_t vp = wasm_f32x4_add(vc4, wasm_f32x4_mul(vc5, vt)); 151 vp = wasm_f32x4_add(vc3, wasm_f32x4_mul(vp, vt)); 152 vp = wasm_f32x4_add(vc2, wasm_f32x4_mul(vp, vt)); 153 vp = wasm_f32x4_add(vc1, wasm_f32x4_mul(vp, vt)); 154 155 // Reconstruct the final f value: 156 // f = s * (1 + t * (c1 + t * (c2 + t * (c3 + t * (c4 + t * c5))))) 157 // = s + (t * s) * (c1 + t * (c2 + t * (c3 + t * (c4 + t * c5)))) 158 // = s + (t * s) * p 159 vt = wasm_f32x4_mul(vt, vs); 160 v128_t vf = wasm_f32x4_add(vs, wasm_f32x4_mul(vt, vp)); 161 162 // For inputs below zero cutoff, replace output with +0.0f. 163 // Note that for NaN inputs, comparison result is false, and outputs are left unchanged. 164 vf = wasm_v128_andnot(vf, wasm_f32x4_lt(vx, vdenorm_cutoff)); 165 166 // Store 4 outputs at a time. 167 wasm_v128_store(output, vf); 168 output += 4; 169 170 // Accumulate computed exponents. 171 vacc = wasm_f32x4_add(vacc, vf); 172 } 173 vacc = wasm_f32x4_add(vacc, wasm_v32x4_shuffle(vacc, vacc, 2, 3, 2, 3)); 174 float vsum = wasm_f32x4_extract_lane(vacc, 0) + wasm_f32x4_extract_lane(vacc, 1); 175 if (elements != 0) { 176 assert(elements >= 1 * sizeof(float)); 177 assert(elements <= 3 * sizeof(float)); 178 // Load 4 inputs at a time. 179 const v128_t vi = wasm_v128_load(input); 180 181 // Subtract maximum input x := i - i_max. This implies x <= 0. 182 const v128_t vx = wasm_f32x4_sub(vi, vi_max); 183 184 // Compute reduced argument elements := round(x / log(2)). 185 v128_t vn = wasm_f32x4_add(vmagic_bias, wasm_f32x4_mul(vx, vlog2e)); 186 187 // Create a floating-point number s (scale) such that s == 2**elements for inputs which don't cause underflow, i.e. 188 // -87.33642 <= x <= 0.0, and -126 <= elements <= 0 accordingly. 189 const v128_t vs = wasm_i32x4_shl(vn, 23); 190 191 // Subtract the large number back to get final elements := round(x / log(2)). 192 vn = wasm_f32x4_sub(vn, vmagic_bias); 193 194 // Compute reduced argument t := x - elements * log(2). 195 // Use Cody-Waite range reduction method (note two constants to represent log(2)) to improve accuracy. 196 v128_t vt = wasm_f32x4_add(vx, wasm_f32x4_mul(vn, vminus_ln2_hi)); 197 vt = wasm_f32x4_add(vt, wasm_f32x4_mul(vn, vminus_ln2_lo)); 198 199 // Compute degree-5 polynomial approximation for exp(t) on [-log(2)/2, log(2)/2]. 200 v128_t vp = wasm_f32x4_add(vc4, wasm_f32x4_mul(vc5, vt)); 201 vp = wasm_f32x4_add(vc3, wasm_f32x4_mul(vp, vt)); 202 vp = wasm_f32x4_add(vc2, wasm_f32x4_mul(vp, vt)); 203 vp = wasm_f32x4_add(vc1, wasm_f32x4_mul(vp, vt)); 204 205 // Reconstruct the final f value: 206 // f = s * (1 + t * (c1 + t * (c2 + t * (c3 + t * (c4 + t * c5))))) 207 // = s + (t * s) * (c1 + t * (c2 + t * (c3 + t * (c4 + t * c5)))) 208 // = s + (t * s) * p 209 vt = wasm_f32x4_mul(vt, vs); 210 v128_t vf = wasm_f32x4_add(vs, wasm_f32x4_mul(vt, vp)); 211 212 // For inputs below zero cutoff, replace output with +0.0f. 213 // Note that for NaN inputs, comparison result is false, and outputs are left unchanged. 214 vf = wasm_v128_andnot(vf, wasm_f32x4_lt(vx, vdenorm_cutoff)); 215 216 if (elements & (2 * sizeof(float))) { 217 // Store and accumulate 2 outputs at a time. 218 const float vf0 = wasm_f32x4_extract_lane(vf, 0); 219 output[0] = vf0; 220 vsum += vf0; 221 222 const float vf1 = wasm_f32x4_extract_lane(vf, 1); 223 output[1] = vf1; 224 vsum += vf1; 225 226 vf = wasm_v32x4_shuffle(vf, vf, 2, 3, 2, 3); 227 output += 2; 228 } 229 if (elements & (1 * sizeof(float))) { 230 // Store 1 output at a time. 231 const float vf0 = wasm_f32x4_extract_lane(vf, 0); 232 *output = vf0; 233 vsum += vf0; 234 } 235 } 236 // Reduce 4 elements in the SIMD register 237 *sum = vsum; 238} 239