1// Copyright 2019 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert ELEMENTS_TILE % 8 == 0
7$assert ELEMENTS_TILE >= 8
8$SIMD_TILE = ELEMENTS_TILE // 8
9$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
10#include <assert.h>
11
12#include <immintrin.h>
13
14#include <xnnpack/common.h>
15#include <xnnpack/vscaleexpminusmax.h>
16
17
// Mask table for the remainder loop: reading 8 consecutive entries starting at
// index (7 - k), for 1 <= k <= 7, yields a vector whose first k 32-bit lanes
// are all-ones (-1) and whose remaining lanes are zero — the per-lane mask
// format expected by _mm256_maskload_ps / _mm256_maskstore_ps.
static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0};
19
// Computes output[i] = scale * exp(input[i] - max) over `elements` bytes of
// float data (a softmax-style scaled exp-minus-max), vectorized with AVX2+FMA.
// Since max is assumed to be >= every input, x = input[i] - max <= 0 and
// exp(x) <= 1, so the result never overflows. exp() is evaluated via
// round-to-nearest base-2 range reduction and a degree-5 polynomial.
//
// Note: `elements` is a BYTE count (asserted to be a multiple of sizeof(float)),
// despite the name.
void xnn_f32_vscaleexpminusmax_ukernel__avx2_p5_x${ELEMENTS_TILE}(
    size_t elements,
    const float* input,
    float* output,
    float scale,
    float max)
{
  assert(elements % sizeof(float) == 0);

  // Adding this bias to x/log(2) shifts the rounded integer part into the low
  // mantissa bits of the float (via round-to-nearest), so it can be extracted
  // with integer operations below (the "magic bias" rounding trick).
  const __m256 vmagic_bias = _mm256_set1_ps(0x1.8000FEp23f);
  // The smallest x for which expf(x) is normalized.
  const __m256 vdenorm_cutoff = _mm256_set1_ps(-0x1.5D589Ep6f);
  const __m256 vlog2e = _mm256_set1_ps(0x1.715476p+0f);
  // -log(2) split into a high part (exact in low-precision bits) and a low
  // correction part, for the two-step Cody-Waite range reduction below.
  const __m256 vminus_ln2_hi = _mm256_set1_ps(-0x1.62E43p-1f);
  const __m256 vminus_ln2_lo = _mm256_set1_ps(0x1.05C61p-29f);

  // Coefficients of the degree-5 polynomial approximation of exp(t),
  // evaluated via Horner's scheme (c5 is folded into the first FMA).
  const __m256 vc1 = _mm256_set1_ps(0x1.FFFFF6p-1f);
  const __m256 vc2 = _mm256_set1_ps(0x1.FFFDC6p-2f);
  const __m256 vc3 = _mm256_set1_ps(0x1.555A80p-3f);
  const __m256 vc4 = _mm256_set1_ps(0x1.573A1Ap-5f);
  const __m256 vc5 = _mm256_set1_ps(0x1.0F9F9Cp-7f);

  const __m256 vscale = _mm256_set1_ps(scale);
  const __m256 vi_max = _mm256_set1_ps(max);

  // Main loop: process ${ELEMENTS_TILE} elements (${SIMD_TILE} AVX registers)
  // per iteration to expose independent FMA chains.
  for (; elements >= ${ELEMENTS_TILE} * sizeof(float); elements -= ${ELEMENTS_TILE} * sizeof(float)) {
    // Load ${ELEMENTS_TILE} (${SIMD_TILE}x8) inputs at a time.
    const __m256 vi0 = _mm256_loadu_ps(input);
    $for N in range(1, SIMD_TILE):
      const __m256 vi${N} = _mm256_loadu_ps(input + ${N * 8});
    input += ${ELEMENTS_TILE};

    // Subtract maximum input x := i - i_max. This implies x <= 0.
    $for N in range(SIMD_TILE):
      const __m256 vx${N} = _mm256_sub_ps(vi${N}, vi_max);

    // Compute reduced argument elements := round(x / log(2)).
    $for N in range(SIMD_TILE):
      __m256 vn${N} = _mm256_fmadd_ps(vx${N}, vlog2e, vmagic_bias);

    // Create a floating-point number s (scale) such that s == 2**elements for inputs which don't cause underflow, i.e.
    // -87.33642 <= x <= 0.0, and -126 <= elements <= 0 accordingly.
    // (Shifting the biased integer left by 23 places it in the exponent field.)
    $for N in range(SIMD_TILE):
      const __m256 vs${N} = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn${N}), 23));

    // Subtract the large number back to get final elements := round(x / log(2)).
    $for N in range(SIMD_TILE):
      vn${N} = _mm256_sub_ps(vn${N}, vmagic_bias);

    // Compute reduced argument t := x - elements * log(2).
    // Use Cody-Waite range reduction method (note two constants to represent log(2)) to improve accuracy.
    $for N in range(SIMD_TILE):
      __m256 vt${N} = _mm256_fmadd_ps(vn${N}, vminus_ln2_hi, vx${N});

    $for N in range(SIMD_TILE):
      vt${N} = _mm256_fmadd_ps(vn${N}, vminus_ln2_lo, vt${N});

    // Compute degree-5 polynomial approximation for exp(t) on [-log(2)/2, log(2)/2].
    $for N in range(SIMD_TILE):
      __m256 vp${N} = _mm256_fmadd_ps(vc5, vt${N}, vc4);

    $for N in range(SIMD_TILE):
      vp${N} = _mm256_fmadd_ps(vp${N}, vt${N}, vc3);

    $for N in range(SIMD_TILE):
      vp${N} = _mm256_fmadd_ps(vp${N}, vt${N}, vc2);

    $for N in range(SIMD_TILE):
      vp${N} = _mm256_fmadd_ps(vp${N}, vt${N}, vc1);

    // Reconstruct the final f value:
    //   f = s * (1 + t * (c1 + t * (c2 + t * (c3 + t * (c4 + t * c5)))))
    //     = s + (t * s) * (c1 + t * (c2 + t * (c3 + t * (c4 + t * c5))))
    //     = s + (t * s) * p
    $for N in range(SIMD_TILE):
      vt${N} = _mm256_mul_ps(vt${N}, vs${N});

    $for N in range(SIMD_TILE):
      __m256 vf${N} = _mm256_fmadd_ps(vt${N}, vp${N}, vs${N});

    // For inputs below zero cutoff, replace output with +0.0f.
    // Note that for NaN inputs, comparison result is false, and outputs are left unchanged.
    // (_CMP_LT_OS is an "ordered" compare: NaN lanes yield 0, so andnot keeps vf.)
    $for N in range(SIMD_TILE):
      vf${N} = _mm256_andnot_ps(_mm256_cmp_ps(vx${N}, vdenorm_cutoff, _CMP_LT_OS), vf${N});

    // Multiply by scale.
    $for N in range(SIMD_TILE):
      vf${N} = _mm256_mul_ps(vf${N}, vscale);

    // Store ${ELEMENTS_TILE} (${SIMD_TILE}x8) outputs at a time.
    _mm256_storeu_ps(output, vf0);
    $for N in range(1, SIMD_TILE):
      _mm256_storeu_ps(output + ${N * 8}, vf${N});
    output += ${ELEMENTS_TILE};
  }
  // Remainder loop: same algorithm, one 8-element vector per iteration.
  for (; elements >= 8 * sizeof(float); elements -= 8 * sizeof(float)) {
    // Load 8 inputs at a time.
    const __m256 vi = _mm256_loadu_ps(input);
    input += 8;

    // Subtract maximum input x := i - i_max. This implies x <= 0.
    const __m256 vx = _mm256_sub_ps(vi, vi_max);

    // Compute reduced argument elements := round(x / log(2)).
    __m256 vn = _mm256_fmadd_ps(vx, vlog2e, vmagic_bias);

    // Create a floating-point number s (scale) such that s == 2**elements for inputs which don't cause underflow, i.e.
    // -87.33642 <= x <= 0.0, and -126 <= elements <= 0 accordingly.
    const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));

    // Subtract the large number back to get final elements := round(x / log(2)).
    vn = _mm256_sub_ps(vn, vmagic_bias);

    // Compute reduced argument t := x - elements * log(2).
    // Use Cody-Waite range reduction method (note two constants to represent log(2)) to improve accuracy.
    __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2_hi, vx);
    vt = _mm256_fmadd_ps(vn, vminus_ln2_lo, vt);

    // Compute degree-5 polynomial approximation for exp(t) on [-log(2)/2, log(2)/2].
    __m256 vp = _mm256_fmadd_ps(vc5, vt, vc4);
    vp = _mm256_fmadd_ps(vp, vt, vc3);
    vp = _mm256_fmadd_ps(vp, vt, vc2);
    vp = _mm256_fmadd_ps(vp, vt, vc1);

    // Reconstruct the final f value:
    //   f = s * (1 + t * (c1 + t * (c2 + t * (c3 + t * (c4 + t * c5)))))
    //     = s + (t * s) * (c1 + t * (c2 + t * (c3 + t * (c4 + t * c5))))
    //     = s + (t * s) * p
    vt = _mm256_mul_ps(vt, vs);
    __m256 vf = _mm256_fmadd_ps(vt, vp, vs);

    // For inputs below zero cutoff, replace output with +0.0f.
    // Note that for NaN inputs, comparison result is false, and outputs are left unchanged.
    vf = _mm256_andnot_ps(_mm256_cmp_ps(vx, vdenorm_cutoff, _CMP_LT_OS), vf);

    // Multiply by scale.
    vf = _mm256_mul_ps(vf, vscale);

    // Store 8 outputs at a time.
    _mm256_storeu_ps(output, vf);
    output += 8;
  }
  // Tail: 1-7 leftover elements, handled with masked load/store.
  if (elements != 0) {
    assert(elements >= 1 * sizeof(float));
    assert(elements <= 7 * sizeof(float));
    // With k = elements / sizeof(float), this reads mask_table starting at
    // index (7 - k): the first k lanes are -1 (active), the rest 0 (inactive).
    const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &mask_table[7] - elements));

    // Load up to 7 inputs at a time.
    const __m256 vi = _mm256_maskload_ps(input, vmask);

    // Subtract maximum input x := i - i_max. This implies x <= 0.
    const __m256 vx = _mm256_sub_ps(vi, vi_max);

    // Compute reduced argument elements := round(x / log(2)).
    __m256 vn = _mm256_fmadd_ps(vx, vlog2e, vmagic_bias);

    // Create a floating-point number s (scale) such that s == 2**elements for inputs which don't cause underflow, i.e.
    // -87.33642 <= x <= 0.0, and -126 <= elements <= 0 accordingly.
    const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));

    // Subtract the large number back to get final elements := round(x / log(2)).
    vn = _mm256_sub_ps(vn, vmagic_bias);

    // Compute reduced argument t := x - elements * log(2).
    // Use Cody-Waite range reduction method (note two constants to represent log(2)) to improve accuracy.
    __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2_hi, vx);
    vt = _mm256_fmadd_ps(vn, vminus_ln2_lo, vt);

    // Compute degree-5 polynomial approximation for exp(t) on [-log(2)/2, log(2)/2].
    __m256 vp = _mm256_fmadd_ps(vc5, vt, vc4);
    vp = _mm256_fmadd_ps(vp, vt, vc3);
    vp = _mm256_fmadd_ps(vp, vt, vc2);
    vp = _mm256_fmadd_ps(vp, vt, vc1);

    // Reconstruct the final f value:
    //   f = s * (1 + t * (c1 + t * (c2 + t * (c3 + t * (c4 + t * c5)))))
    //     = s + (t * s) * (c1 + t * (c2 + t * (c3 + t * (c4 + t * c5))))
    //     = s + (t * s) * p
    vt = _mm256_mul_ps(vt, vs);
    __m256 vf = _mm256_fmadd_ps(vt, vp, vs);

    // For inputs below zero cutoff, replace output with +0.0f.
    // Note that for NaN inputs, comparison result is false, and outputs are left unchanged.
    vf = _mm256_andnot_ps(_mm256_cmp_ps(vx, vdenorm_cutoff, _CMP_LT_OS), vf);

    // Multiply by scale.
    vf = _mm256_mul_ps(vf, vscale);

    // Store up to 7 outputs at a time.
    _mm256_maskstore_ps(output, vmask, vf);
  }
}
212