1// Copyright 2019 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert BATCH_TILE % 8 == 0
7$assert BATCH_TILE >= 8
8$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
9$assert OP in ["ADD", "DIV", "MAX", "MIN", "MUL", "SUB", "SQRDIFF"]
10$assert ACTIVATION in ["LINEAR", "MINMAX"]
11#include <assert.h>
12
13#include <immintrin.h>
14
15#include <xnnpack/common.h>
16#include <xnnpack/vbinary.h>
17
18
// Lane-mask source for the remainder path: seven all-ones words followed by seven
// zero words. An 8-word load that starts k words before entry 7 (1 <= k <= 7)
// yields k all-ones lanes followed by 8 - k zero lanes.
static const int32_t mask_table[14] = {
  -1, -1, -1, -1, -1, -1, -1,
   0,  0,  0,  0,  0,  0,  0,
};
20
21$_MM256_OP_PS = {
22$  "ADD": lambda x, y: "_mm256_add_ps(%s, %s)" % (x, y),
23$  "DIV": lambda x, y: "_mm256_div_ps(%s, %s)" % (x, y),
24$  "MAX": lambda x, y: "_mm256_max_ps(%s, %s)" % (x, y),
25$  "MIN": lambda x, y: "_mm256_min_ps(%s, %s)" % (x, y),
26$  "MUL": lambda x, y: "_mm256_mul_ps(%s, %s)" % (x, y),
27$  "SUB": lambda x, y: "_mm256_sub_ps(%s, %s)" % (x, y),
28$  "SQRDIFF": lambda x, y: "_mm256_sub_ps(%s, %s)" % (x, y),
29$}[OP]
30$SUFFIX = {"LINEAR": "", "MINMAX": "_minmax"}[ACTIVATION]
31$PARAMS = {"LINEAR": "xnn_f32_default_params", "MINMAX": "xnn_f32_minmax_params"}[ACTIVATION]
32void xnn_f32_v${OP.lower()}${SUFFIX}_ukernel__avx_x${BATCH_TILE}(
33    size_t n,
34    const float* a,
35    const float* b,
36    float* y,
37    const union ${PARAMS} params[restrict XNN_MIN_ELEMENTS(1)])
38{
39  assert(n != 0);
40  assert(n % sizeof(float) == 0);
41  assert(a != NULL);
42  assert(b != NULL);
43  assert(y != NULL);
44
45  $if ACTIVATION == "MINMAX":
46    const __m256 vy_min = _mm256_broadcast_ps((const __m128*) params->sse.min);
47    const __m256 vy_max = _mm256_broadcast_ps((const __m128*) params->sse.max);
48
49  for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
50    const __m256 va${ABC[0:8]} = _mm256_loadu_ps(a);
51    $for N in range(8, BATCH_TILE, 8):
52      const __m256 va${ABC[N:N+8]} = _mm256_loadu_ps(a + ${N});
53    a += ${BATCH_TILE};
54
55    const __m256 vb${ABC[0:8]} = _mm256_loadu_ps(b);
56    $for N in range(8, BATCH_TILE, 8):
57      const __m256 vb${ABC[N:N+8]} = _mm256_loadu_ps(b + ${N});
58    b += ${BATCH_TILE};
59
60    $for N in range(0, BATCH_TILE, 8):
61      __m256 vy${ABC[N:N+8]} = ${_MM256_OP_PS("va" + ABC[N:N+8], "vb" + ABC[N:N+8])};
62
63    $if OP == "SQRDIFF":
64      $for N in range(0, BATCH_TILE, 8):
65        vy${ABC[N:N+8]} = _mm256_mul_ps(vy${ABC[N:N+8]}, vy${ABC[N:N+8]});
66
67    $if ACTIVATION == "MINMAX":
68      $for N in range(0, BATCH_TILE, 8):
69        vy${ABC[N:N+8]} = _mm256_max_ps(vy${ABC[N:N+8]}, vy_min);
70
71      $for N in range(0, BATCH_TILE, 8):
72        vy${ABC[N:N+8]} = _mm256_min_ps(vy${ABC[N:N+8]}, vy_max);
73
74    _mm256_storeu_ps(y, vy${ABC[0:8]});
75    $for N in range(8, BATCH_TILE, 8):
76      _mm256_storeu_ps(y + ${N}, vy${ABC[N:N+8]});
77    y += ${BATCH_TILE};
78  }
79  $if BATCH_TILE > 8:
80    for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
81      const __m256 va = _mm256_loadu_ps(a);
82      a += 8;
83
84      const __m256 vb = _mm256_loadu_ps(b);
85      b += 8;
86
87      __m256 vy = ${_MM256_OP_PS("va", "vb")};
88      $if OP == "SQRDIFF":
89        vy = _mm256_mul_ps(vy, vy);
90      $if ACTIVATION == "MINMAX":
91        vy = _mm256_max_ps(vy, vy_min);
92        vy = _mm256_min_ps(vy, vy_max);
93      _mm256_storeu_ps(y, vy);
94      y += 8;
95    }
96  if XNN_UNLIKELY(n != 0) {
97    assert(n >= 1 * sizeof(float));
98    assert(n <= 7 * sizeof(float));
99    __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &mask_table[7] - n));
100
101    const __m256 va = _mm256_maskload_ps(a, vmask);
102    const __m256 vb = _mm256_maskload_ps(b, vmask);
103
104    __m256 vy = ${_MM256_OP_PS("va", "vb")};
105    $if OP == "SQRDIFF":
106      vy = _mm256_mul_ps(vy, vy);
107    $if ACTIVATION == "MINMAX":
108      vy = _mm256_max_ps(vy, vy_min);
109      vy = _mm256_min_ps(vy, vy_max);
110
111    // _mm256_maskstore_ps(y, vmask, vy) could be used here, but triggers msan failures (probably an msan bug).
112    __m128 vy_lo = _mm256_castps256_ps128(vy);
113    if (n & (4 * sizeof(float))) {
114      _mm_storeu_ps(y, vy_lo);
115      vy_lo = _mm256_extractf128_ps(vy, 1);
116      y += 4;
117    }
118    if (n & (2 * sizeof(float))) {
119      _mm_storel_pi((__m64*) y, vy_lo);
120      vy_lo = _mm_movehl_ps(vy_lo, vy_lo);
121      y += 2;
122    }
123    if (n & (1 * sizeof(float))) {
124      _mm_store_ss(y, vy_lo);
125    }
126  }
127}
128