// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 4 == 0
$assert BATCH_TILE >= 4
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert OP in ["ADD", "DIV", "RDIV", "MAX", "MIN", "MUL", "SUB", "RSUB", "SQRDIFF"]
$assert ACTIVATION in ["LINEAR", "MINMAX", "RELU"]
#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/common.h>
#include <xnnpack/vbinary.h>


$WASM_F32X4_OP = {
$  "ADD": lambda x: "wasm_f32x4_add(%s, vb)" % x,
$  "DIV": lambda x: "wasm_f32x4_div(%s, vb)" % x,
$  "RDIV": lambda x: "wasm_f32x4_div(vb, %s)" % x,
$  "MAX": lambda x: "wasm_f32x4_max(%s, vb)" % x,
$  "MIN": lambda x: "wasm_f32x4_min(%s, vb)" % x,
$  "MUL": lambda x: "wasm_f32x4_mul(%s, vb)" % x,
$  "SUB": lambda x: "wasm_f32x4_sub(%s, vb)" % x,
$  "RSUB": lambda x: "wasm_f32x4_sub(vb, %s)" % x,
$  "SQRDIFF": lambda x: "wasm_f32x4_sub(%s, vb)" % x,
$}[OP]
$ARCH_SUFFIX = "" if ACTIVATION in ["LINEAR", "RELU"] and OP not in ["MIN", "MAX"] else "_x86" if X86 else "_arm"
$ACTIVATION_SUFFIX = {"LINEAR": ""}.get(ACTIVATION, "_" + ACTIVATION.lower())
$PARAMS = {"LINEAR": "xnn_f32_default_params", "RELU": "xnn_f32_relu_params", "MINMAX": "xnn_f32_minmax_params"}[ACTIVATION]
void xnn_f32_v${OP.lower()}c${ACTIVATION_SUFFIX}_ukernel__wasmsimd${ARCH_SUFFIX}_x${BATCH_TILE}(
    size_t n,
    const float* a,
    const float* b,
    float* y,
    const union ${PARAMS} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);
  assert(a != NULL);
  assert(b != NULL);
  assert(y != NULL);

  $if ACTIVATION == "MINMAX":
    const v128_t vy_min = wasm_v32x4_load_splat(&params->scalar.min);
    const v128_t vy_max = wasm_v32x4_load_splat(&params->scalar.max);
  $elif ACTIVATION == "RELU":
    const v128_t vzero = wasm_f32x4_splat(0.0f);
  // The scalar operand b is loaded once and broadcast across all four lanes.
  const v128_t vb = wasm_v32x4_load_splat(b);
  for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
    const v128_t va${ABC[0:4]} = wasm_v128_load(a);
    $for N in range(4, BATCH_TILE, 4):
      const v128_t va${ABC[N:N+4]} = wasm_v128_load(a + ${N});
    a += ${BATCH_TILE};

    $if OP == "MIN" and X86:
      // On x86, wasm_f32x4_min lowers to a multi-instruction NaN-handling sequence,
      // so select the smaller operand with a comparison and bitselect instead.
      $for N in range(0, BATCH_TILE, 4):
        const v128_t vm${ABC[N:N+4]} = wasm_f32x4_lt(va${ABC[N:N+4]}, vb);

      $for N in range(0, BATCH_TILE, 4):
        v128_t vy${ABC[N:N+4]} = wasm_v128_bitselect(va${ABC[N:N+4]}, vb, vm${ABC[N:N+4]});
    $elif OP == "MAX" and X86:
      // Same trick as MIN: comparison + bitselect instead of wasm_f32x4_max.
      $for N in range(0, BATCH_TILE, 4):
        const v128_t vm${ABC[N:N+4]} = wasm_f32x4_le(va${ABC[N:N+4]}, vb);

      $for N in range(0, BATCH_TILE, 4):
        v128_t vy${ABC[N:N+4]} = wasm_v128_bitselect(vb, va${ABC[N:N+4]}, vm${ABC[N:N+4]});
    $else:
      $for N in range(0, BATCH_TILE, 4):
        v128_t vy${ABC[N:N+4]} = ${WASM_F32X4_OP("va" + ABC[N:N+4])};

    $if OP == "SQRDIFF":
      $for N in range(0, BATCH_TILE, 4):
        vy${ABC[N:N+4]} = wasm_f32x4_mul(vy${ABC[N:N+4]}, vy${ABC[N:N+4]});

    $if ACTIVATION == "MINMAX":
      $if X86:
        // x86 variant: clamp via comparison + bitselect rather than wasm_f32x4_min/max.
        $for N in range(0, BATCH_TILE, 4):
          const v128_t vltmask${ABC[N:N+4]} = wasm_f32x4_lt(vy${ABC[N:N+4]}, vy_min);

        $for N in range(0, BATCH_TILE, 4):
          const v128_t vngtmask${ABC[N:N+4]} = wasm_f32x4_le(vy${ABC[N:N+4]}, vy_max);
          vy${ABC[N:N+4]} = wasm_v128_bitselect(vy_min, vy${ABC[N:N+4]}, vltmask${ABC[N:N+4]});

        $for N in range(0, BATCH_TILE, 4):
          vy${ABC[N:N+4]} = wasm_v128_bitselect(vy${ABC[N:N+4]}, vy_max, vngtmask${ABC[N:N+4]});
      $else:
        $for N in range(0, BATCH_TILE, 4):
          vy${ABC[N:N+4]} = wasm_f32x4_max(vy${ABC[N:N+4]}, vy_min);

        $for N in range(0, BATCH_TILE, 4):
          vy${ABC[N:N+4]} = wasm_f32x4_min(vy${ABC[N:N+4]}, vy_max);
    $elif ACTIVATION == "RELU":
      // RELU: negative floats have the sign bit set, so they compare below zero as
      // signed integers and an integer max with zero clears them.
      $for N in range(0, BATCH_TILE, 4):
        vy${ABC[N:N+4]} = wasm_i32x4_max(vy${ABC[N:N+4]}, vzero);

    wasm_v128_store(y, vy${ABC[0:4]});
    $for N in range(4, BATCH_TILE, 4):
      wasm_v128_store(y + ${N}, vy${ABC[N:N+4]});
    y += ${BATCH_TILE};
  }
  $if BATCH_TILE > 4:
    // Remainder loop: process 4 elements at a time.
    for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) {
      const v128_t va = wasm_v128_load(a);
      a += 4;

      $if OP == "MIN" and X86:
        const v128_t vm = wasm_f32x4_lt(va, vb);
        v128_t vy = wasm_v128_bitselect(va, vb, vm);
      $elif OP == "MAX" and X86:
        const v128_t vm = wasm_f32x4_le(va, vb);
        v128_t vy = wasm_v128_bitselect(vb, va, vm);
      $else:
        v128_t vy = ${WASM_F32X4_OP("va")};
      $if OP == "SQRDIFF":
        vy = wasm_f32x4_mul(vy, vy);

      $if ACTIVATION == "MINMAX":
        $if X86:
          const v128_t vltmask = wasm_f32x4_lt(vy, vy_min);
          const v128_t vngtmask = wasm_f32x4_le(vy, vy_max);
          vy = wasm_v128_bitselect(vy_min, vy, vltmask);
          vy = wasm_v128_bitselect(vy, vy_max, vngtmask);
        $else:
          vy = wasm_f32x4_max(vy, vy_min);
          vy = wasm_f32x4_min(vy, vy_max);
      $elif ACTIVATION == "RELU":
        vy = wasm_i32x4_max(vy, vzero);

      wasm_v128_store(y, vy);
      y += 4;
    }
  // Tail: 1 to 3 leftover elements; compute a full vector, then store 2 and/or 1 lanes.
  if XNN_UNLIKELY(n != 0) {
    const v128_t va = wasm_v128_load(a);

    $if OP == "MIN" and X86:
      const v128_t vm = wasm_f32x4_lt(va, vb);
      v128_t vy = wasm_v128_bitselect(va, vb, vm);
    $elif OP == "MAX" and X86:
      const v128_t vm = wasm_f32x4_le(va, vb);
      v128_t vy = wasm_v128_bitselect(vb, va, vm);
    $else:
      v128_t vy = ${WASM_F32X4_OP("va")};
    $if OP == "SQRDIFF":
      vy = wasm_f32x4_mul(vy, vy);

    $if ACTIVATION == "MINMAX":
      $if X86:
        const v128_t vltmask = wasm_f32x4_lt(vy, vy_min);
        const v128_t vngtmask = wasm_f32x4_le(vy, vy_max);
        vy = wasm_v128_bitselect(vy_min, vy, vltmask);
        vy = wasm_v128_bitselect(vy, vy_max, vngtmask);
      $else:
        vy = wasm_f32x4_max(vy, vy_min);
        vy = wasm_f32x4_min(vy, vy_max);
    $elif ACTIVATION == "RELU":
      vy = wasm_i32x4_max(vy, vzero);

    if (n & (2 * sizeof(float))) {
      *((double*) y) = wasm_f64x2_extract_lane(vy, 0);
      vy = wasm_v32x4_shuffle(vy, vy, 2, 3, 2, 3);
      y += 2;
    }
    if (n & (1 * sizeof(float))) {
      *y = wasm_f32x4_extract_lane(vy, 0);
    }
  }
}
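
// Example instantiation (illustrative, derived from the naming scheme above):
// running this template through tools/xngen with OP=ADD, ACTIVATION=MINMAX,
// X86=1, BATCH_TILE=8 yields the kernel
//   xnn_f32_vaddc_minmax_ukernel__wasmsimd_x86_x8
// i.e. v${OP.lower()}c + ${ACTIVATION_SUFFIX} + ${ARCH_SUFFIX} + _x${BATCH_TILE}.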