1// Copyright 2020 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert BATCH_TILE % 4 == 0 7$assert BATCH_TILE >= 4 8$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 9$assert OP in ["ADD", "DIV", "MAX", "MIN", "MUL", "SUB", "SQRDIFF"] 10$assert ACTIVATION in ["LINEAR", "MINMAX", "RELU"] 11#include <assert.h> 12 13#include <wasm_simd128.h> 14 15#include <xnnpack/common.h> 16#include <xnnpack/vbinary.h> 17 18 19$WASM_F32X4_OP = { 20$ "ADD": "wasm_f32x4_add", 21$ "DIV": "wasm_f32x4_div", 22$ "MAX": "wasm_f32x4_max", 23$ "MIN": "wasm_f32x4_min", 24$ "MUL": "wasm_f32x4_mul", 25$ "SUB": "wasm_f32x4_sub", 26$ "SQRDIFF": "wasm_f32x4_sub", 27$}[OP] 28$ARCH_SUFFIX = "" if ACTIVATION in ["LINEAR", "RELU"] and OP not in ["MIN", "MAX"] else "_x86" if X86 else "_arm" 29$ACTIVATION_SUFFIX = {"LINEAR": ""}.get(ACTIVATION, "_" + ACTIVATION.lower()) 30$PARAMS = {"LINEAR": "xnn_f32_default_params", "RELU": "xnn_f32_relu_params", "MINMAX": "xnn_f32_minmax_params"}[ACTIVATION] 31void xnn_f32_v${OP.lower()}${ACTIVATION_SUFFIX}_ukernel__wasmsimd${ARCH_SUFFIX}_x${BATCH_TILE}( 32 size_t n, 33 const float* a, 34 const float* b, 35 float* y, 36 const union ${PARAMS} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN 37{ 38 assert(n != 0); 39 assert(n % sizeof(float) == 0); 40 assert(a != NULL); 41 assert(b != NULL); 42 assert(y != NULL); 43 44 $if ACTIVATION == "MINMAX": 45 const v128_t vy_min = wasm_v32x4_load_splat(¶ms->scalar.min); 46 const v128_t vy_max = wasm_v32x4_load_splat(¶ms->scalar.max); 47 $elif ACTIVATION == "RELU": 48 const v128_t vzero = wasm_f32x4_splat(0.0f); 49 50 for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) { 51 const v128_t va${ABC[0:4]} = wasm_v128_load(a); 52 $for N in range(4, BATCH_TILE, 4): 53 const v128_t va${ABC[N:N+4]} = wasm_v128_load(a + ${N}); 54 a += ${BATCH_TILE}; 55 56 const v128_t vb${ABC[0:4]} = wasm_v128_load(b); 57 $for N in range(4, BATCH_TILE, 4): 58 const v128_t vb${ABC[N:N+4]} = wasm_v128_load(b + ${N}); 59 b += ${BATCH_TILE}; 60 61 $if OP == "MIN" and X86: 62 $for N in range(0, BATCH_TILE, 4): 63 const v128_t vm${ABC[N:N+4]} = wasm_f32x4_lt(va${ABC[N:N+4]}, vb${ABC[N:N+4]}); 64 65 $for N in range(0, BATCH_TILE, 4): 66 v128_t vy${ABC[N:N+4]} = wasm_v128_bitselect(va${ABC[N:N+4]}, vb${ABC[N:N+4]}, vm${ABC[N:N+4]}); 67 $elif OP == "MAX" and X86: 68 $for N in range(0, BATCH_TILE, 4): 69 const v128_t vm${ABC[N:N+4]} = wasm_f32x4_le(va${ABC[N:N+4]}, vb${ABC[N:N+4]}); 70 71 $for N in range(0, BATCH_TILE, 4): 72 v128_t vy${ABC[N:N+4]} = wasm_v128_bitselect(vb${ABC[N:N+4]}, va${ABC[N:N+4]}, vm${ABC[N:N+4]}); 73 $else: 74 $for N in range(0, BATCH_TILE, 4): 75 v128_t vy${ABC[N:N+4]} = ${WASM_F32X4_OP}(va${ABC[N:N+4]}, vb${ABC[N:N+4]}); 76 77 $if OP == "SQRDIFF": 78 $for N in range(0, BATCH_TILE, 4): 79 vy${ABC[N:N+4]} = wasm_f32x4_mul(vy${ABC[N:N+4]}, vy${ABC[N:N+4]}); 80 81 $if ACTIVATION == "MINMAX": 82 $if X86: 83 $for N in range(0, BATCH_TILE, 4): 84 const v128_t vltmask${ABC[N:N+4]} = wasm_f32x4_lt(vy${ABC[N:N+4]}, vy_min); 85 86 $for N in range(0, BATCH_TILE, 4): 87 const v128_t vngtmask${ABC[N:N+4]} = wasm_f32x4_le(vy${ABC[N:N+4]}, vy_max); 88 vy${ABC[N:N+4]} = wasm_v128_bitselect(vy_min, vy${ABC[N:N+4]}, vltmask${ABC[N:N+4]}); 89 90 $for N in range(0, BATCH_TILE, 4): 91 vy${ABC[N:N+4]} = wasm_v128_bitselect(vy${ABC[N:N+4]}, vy_max, vngtmask${ABC[N:N+4]}); 92 $else: 93 $for N in range(0, BATCH_TILE, 4): 94 vy${ABC[N:N+4]} = wasm_f32x4_max(vy${ABC[N:N+4]}, vy_min); 95 96 $for N in range(0, BATCH_TILE, 4): 97 vy${ABC[N:N+4]} = wasm_f32x4_min(vy${ABC[N:N+4]}, vy_max); 98 $elif ACTIVATION == "RELU": 99 $for N in range(0, BATCH_TILE, 4): 100 vy${ABC[N:N+4]} = wasm_i32x4_max(vy${ABC[N:N+4]}, vzero); 101 102 wasm_v128_store(y, vy${ABC[0:4]}); 103 $for N in range(4, BATCH_TILE, 4): 104 wasm_v128_store(y + ${N}, vy${ABC[N:N+4]}); 105 y += ${BATCH_TILE}; 106 } 107 $if BATCH_TILE > 4: 108 for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) { 109 const v128_t va = wasm_v128_load(a); 110 a += 4; 111 112 const v128_t vb = wasm_v128_load(b); 113 b += 4; 114 115 $if OP == "MIN" and X86: 116 const v128_t vm = wasm_f32x4_lt(va, vb); 117 v128_t vy = wasm_v128_bitselect(va, vb, vm); 118 $elif OP == "MAX" and X86: 119 const v128_t vm = wasm_f32x4_le(va, vb); 120 v128_t vy = wasm_v128_bitselect(vb, va, vm); 121 $else: 122 v128_t vy = ${WASM_F32X4_OP}(va, vb); 123 $if OP == "SQRDIFF": 124 vy = wasm_f32x4_mul(vy, vy); 125 126 $if ACTIVATION == "MINMAX": 127 $if X86: 128 const v128_t vltmask = wasm_f32x4_lt(vy, vy_min); 129 const v128_t vngtmask = wasm_f32x4_le(vy, vy_max); 130 vy = wasm_v128_bitselect(vy_min, vy, vltmask); 131 vy = wasm_v128_bitselect(vy, vy_max, vngtmask); 132 $else: 133 vy = wasm_f32x4_max(vy, vy_min); 134 vy = wasm_f32x4_min(vy, vy_max); 135 $elif ACTIVATION == "RELU": 136 vy = wasm_i32x4_max(vy, vzero); 137 138 wasm_v128_store(y, vy); 139 y += 4; 140 } 141 if XNN_UNLIKELY(n != 0) { 142 const v128_t va = wasm_v128_load(a); 143 const v128_t vb = wasm_v128_load(b); 144 145 $if OP == "MIN" and X86: 146 const v128_t vm = wasm_f32x4_lt(va, vb); 147 v128_t vy = wasm_v128_bitselect(va, vb, vm); 148 $elif OP == "MAX" and X86: 149 const v128_t vm = wasm_f32x4_le(va, vb); 150 v128_t vy = wasm_v128_bitselect(vb, va, vm); 151 $else: 152 v128_t vy = ${WASM_F32X4_OP}(va, vb); 153 $if OP == "SQRDIFF": 154 vy = wasm_f32x4_mul(vy, vy); 155 156 $if ACTIVATION == "MINMAX": 157 $if X86: 158 const v128_t vltmask = wasm_f32x4_lt(vy, vy_min); 159 const v128_t vngtmask = wasm_f32x4_le(vy, vy_max); 160 vy = wasm_v128_bitselect(vy_min, vy, vltmask); 161 vy = wasm_v128_bitselect(vy, vy_max, vngtmask); 162 $else: 163 vy = wasm_f32x4_max(vy, vy_min); 164 vy = wasm_f32x4_min(vy, vy_max); 165 $elif ACTIVATION == "RELU": 166 vy = wasm_i32x4_max(vy, vzero); 167 168 if (n & (2 * sizeof(float))) { 169 *((double*) y) = wasm_f64x2_extract_lane(vy, 0); 170 vy = wasm_v32x4_shuffle(vy, vy, 2, 3, 2, 3); 171 y += 2; 172 } 173 if (n & (1 * sizeof(float))) { 174 *y = wasm_f32x4_extract_lane(vy, 0); 175 } 176 } 177} 178