// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 4 == 0
$assert BATCH_TILE >= 4
$assert SSE in [1, 2, 4]
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$SSE_HEADER = {1: "xmmintrin.h", 2: "emmintrin.h", 4: "smmintrin.h"}[SSE]
#include <assert.h>

#include <${SSE_HEADER}>

#include <xnnpack/common.h>
#include <xnnpack/vunary.h>


$ISA = {1: "sse", 2: "sse2", 4: "sse41"}[SSE]
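// LeakyReLU: y[i] = x[i] for x[i] >= 0, and y[i] = x[i] * slope for x[i] < 0.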
void xnn_f32_vlrelu_ukernel__${ISA}_x${BATCH_TILE}(
    size_t n,
    const float* x,
    float* y,
    const union xnn_f32_lrelu_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);

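  // Load the slope parameter as a 4-lane vector. The SSE1 variant also keeps a zero
  // constant for splitting inputs into their non-negative and non-positive parts.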
  const __m128 vslope = _mm_load_ps(params->sse.slope);
  $if SSE == 1:
    const __m128 vzero = _mm_setzero_ps();
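  // Main loop: process ${BATCH_TILE} elements per iteration.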
  for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
    $if SSE == 1:
      __m128 vx${ABC[0:4]} = _mm_loadu_ps(x);
      $for N in range(4, BATCH_TILE, 4):
        __m128 vx${ABC[N:N+4]} = _mm_loadu_ps(x + ${N});
    $else:
      const __m128 vx${ABC[0:4]} = _mm_loadu_ps(x);
      $for N in range(4, BATCH_TILE, 4):
        const __m128 vx${ABC[N:N+4]} = _mm_loadu_ps(x + ${N});
    x += ${BATCH_TILE};

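    // First pass: SSE1 takes the non-negative part of x (max with zero) and clamps x to
    // its non-positive part; SSE2/SSE4.1 compute x * slope, and SSE2 additionally builds
    // an all-ones mask for lanes whose sign bit is set.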
    $for N in range(0, BATCH_TILE, 4):
      $if SSE == 1:
        __m128 vacc${ABC[N:N+4]} = _mm_max_ps(_mm_setzero_ps(), vx${ABC[N:N+4]});
        vx${ABC[N:N+4]} = _mm_min_ps(vx${ABC[N:N+4]}, vzero);
      $else:
        __m128 vacc${ABC[N:N+4]} = _mm_mul_ps(vx${ABC[N:N+4]}, vslope);
        $if SSE == 2:
          const __m128 vmask${ABC[N:N+4]} = _mm_castsi128_ps(_mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx${ABC[N:N+4]})));

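    // Second pass: SSE1 adds the scaled negative part to the positive part; SSE2 selects
    // x * slope or x per lane with the sign mask; SSE4.1 blends x * slope over x wherever
    // the sign bit of x is set.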
    $for N in range(0, BATCH_TILE, 4):
      $if SSE == 1:
        vacc${ABC[N:N+4]} = _mm_add_ps(vacc${ABC[N:N+4]}, _mm_mul_ps(vx${ABC[N:N+4]}, vslope));
      $elif SSE == 2:
        vacc${ABC[N:N+4]} = _mm_or_ps(_mm_and_ps(vacc${ABC[N:N+4]}, vmask${ABC[N:N+4]}), _mm_andnot_ps(vmask${ABC[N:N+4]}, vx${ABC[N:N+4]}));
      $elif SSE == 4:
        vacc${ABC[N:N+4]} = _mm_blendv_ps(vx${ABC[N:N+4]}, vacc${ABC[N:N+4]}, vx${ABC[N:N+4]});

    _mm_storeu_ps(y, vacc${ABC[0:4]});
    $for N in range(4, BATCH_TILE, 4):
      _mm_storeu_ps(y + ${N}, vacc${ABC[N:N+4]});
    y += ${BATCH_TILE};
  }
  $if BATCH_TILE > 4:
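    // Process remaining full vectors of 4 elements when fewer than ${BATCH_TILE} are left.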
    for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) {
      $if SSE == 1:
        __m128 vx = _mm_loadu_ps(x);
      $else:
        const __m128 vx = _mm_loadu_ps(x);
      x += 4;

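      // Same per-vector LeakyReLU computation as in the main loop.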
      $if SSE == 1:
        __m128 vacc = _mm_max_ps(_mm_setzero_ps(), vx);
        vx = _mm_min_ps(vx, vzero);
        vacc = _mm_add_ps(vacc, _mm_mul_ps(vx, vslope));
      $else:
        __m128 vacc = _mm_mul_ps(vx, vslope);
        $if SSE == 2:
          const __m128 vmask = _mm_castsi128_ps(_mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx)));
          vacc = _mm_or_ps(_mm_and_ps(vacc, vmask), _mm_andnot_ps(vmask, vx));
        $elif SSE == 4:
          vacc = _mm_blendv_ps(vx, vacc, vx);

      _mm_storeu_ps(y, vacc);
      y += 4;
    }
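  // Handle the final 1-3 elements: compute a full vector and store only the valid lanes.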
  if XNN_UNLIKELY(n != 0) {
    $if SSE == 1:
      __m128 vx = _mm_loadu_ps(x);

      __m128 vacc = _mm_max_ps(_mm_setzero_ps(), vx);
      vx = _mm_min_ps(vx, vzero);
      vacc = _mm_add_ps(vacc, _mm_mul_ps(vx, vslope));
    $else:
      const __m128 vx = _mm_loadu_ps(x);

      __m128 vacc = _mm_mul_ps(vx, vslope);
      $if SSE == 2:
        const __m128 vmask = _mm_castsi128_ps(_mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx)));
        vacc = _mm_or_ps(_mm_and_ps(vacc, vmask), _mm_andnot_ps(vmask, vx));
      $elif SSE == 4:
        vacc = _mm_blendv_ps(vx, vacc, vx);

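    // Store 2 lanes if the 2-element bit of n is set, then 1 more lane if the 1-element bit is set.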
    if (n & (2 * sizeof(float))) {
      _mm_storel_pi((__m64*) y, vacc);
      vacc = _mm_movehl_ps(vacc, vacc);
      y += 2;
    }
    if (n & (1 * sizeof(float))) {
      _mm_store_ss(y, vacc);
    }
  }
}