// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

6$assert BATCH_TILE >= 1
7$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
8#include <assert.h>
9
10#include <xnnpack/common.h>
11#include <xnnpack/math.h>
12#include <xnnpack/hswish.h>
13
14
15$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32"
16$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32"
17void xnn_f32_hswish_ukernel__${"wasm" if WASM else "scalar"}_x${BATCH_TILE}(
18    size_t n,
19    const float* x,
20    float* y,
21    const union xnn_f32_hswish_params params[restrict XNN_MIN_ELEMENTS(1)])
22{
23  assert(n != 0);
24  assert(n % sizeof(float) == 0);
25
26  const float vsixth = params->scalar.sixth;
27  const float vthree = params->scalar.three;
28  const float vsix = params->scalar.six;
29  const float vzero = 0.0f;
30  assert(vthree == 3.0f);
31  assert(vsix == 6.0f);
32
33  $if BATCH_TILE > 1:
34    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
35      $for N in range(BATCH_TILE):
36        float vx${ABC[N]} = x[${N}];
37      x += ${BATCH_TILE};
38
39      $for N in range(BATCH_TILE):
40        float vacc${ABC[N]} = vx${ABC[N]} + vthree;
41        vx${ABC[N]} *= vsixth;
42
43      $for N in range(BATCH_TILE):
44        vacc${ABC[N]} = ${MAX_F32}(vacc${ABC[N]}, vzero);
45
46      $for N in range(BATCH_TILE):
47        vacc${ABC[N]} = ${MIN_F32}(vacc${ABC[N]}, vsix);
48
49      $for N in range(BATCH_TILE):
50        vacc${ABC[N]} *= vx${ABC[N]};
51
52      $for N in range(BATCH_TILE):
53        y[${N}] = vacc${ABC[N]};
54      y += ${BATCH_TILE};
55    }
56    if XNN_UNLIKELY(n != 0) {
57      $if BATCH_TILE > 2:
58        do {
59          float vx = *x++;
60          float vacc = vx + vthree;
61          vx *= vsixth;
62          vacc = ${MAX_F32}(vacc, vzero);
63          vacc = ${MIN_F32}(vacc, vsix);
64          vacc *= vx;
65          *y++ = vacc;
66          n -= sizeof(float);
67        } while (n != 0);
68      $else:
69        float vx = *x;
70        float vacc = vx + vthree;
71        vx *= vsixth;
72        vacc = ${MAX_F32}(vacc, vzero);
73        vacc = ${MIN_F32}(vacc, vsix);
74        vacc *= vx;
75        *y = vacc;
76    }
77  $else:
78    for (; n >= sizeof(float); n -= sizeof(float)) {
79      float vx = *x++;
80      float vacc = vx + vthree;
81      vx *= vsixth;
82      vacc = ${MAX_F32}(vacc, vzero);
83      vacc = ${MIN_F32}(vacc, vsix);
84      vacc *= vx;
85      *y++ = vacc;
86    }
87}
88