// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert VARIANT in ["LD64", "LD128", "EXTENDED"]
$assert MR <= 4
#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/gemm.h>
#include <xnnpack/math.h>


$LOAD_SUFFIX = {"LD128": "_ld128", "LD64": "_ld64", "EXTENDED": ""}[VARIANT]
$GEMM_SUFFIX = "_xw" if VARIANT == "EXTENDED" else ""
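// QS8 GEMM microkernel for a ${MR}x4 output tile, consuming K in groups of
// 8 signed 8-bit elements (c8) with WASM SIMD.  The VARIANT selects how the
// packed weights are read: LD64 and LD128 sign-extend int8 weights inside the
// kernel, while EXTENDED ("_xw") expects weights pre-extended to int16 at
// packing time.  kc is rounded up to a multiple of 8 to match this blocking.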
void xnn_qs8_gemm${GEMM_SUFFIX}_minmax_ukernel_${MR}x4c8__wasmsimd${LOAD_SUFFIX}(
    size_t mr,
    size_t nc,
    size_t kc,
    const int8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    int8_t* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_qs8_gemm${GEMM_SUFFIX}_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  kc = round_up_po2(kc, 8);
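  // Set up per-row pointers into A and C.  Rows past mr alias the previous
  // row, so those rows are computed redundantly but never accessed out of
  // bounds.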
  const int8_t* a0 = a;
  int8_t* c0 = c;
  $for M in range(1, MR):
    const int8_t* a${M} = (const int8_t*) ((uintptr_t) a${M-1} + a_stride);
    int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }

  const v128_t vzero = wasm_f64x2_splat(0.0);
  do {
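    // Initialize the accumulators with the per-column int32 biases stored at
    // the start of the packed weights.  The bias bits are inserted into lane 0
    // by reinterpreting them as a float; lanes 1-3 stay zero.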
    $for N in range(4):
      v128_t vacc0x${N} = wasm_f32x4_replace_lane(vzero, 0, ((const float*) w)[${N}]);
    $for M in range(1, MR):
      $for N in range(4):
        v128_t vacc${M}x${N} = vacc0x${N};
    w = (const void*) ((uintptr_t) w + 4 * sizeof(int32_t));

    size_t k = 0;
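    // Main loop: each iteration consumes 8 K elements from every row of A and
    // from each of the 4 packed weight columns.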
    while (k < kc) {
      $for M in range(MR):
        const v128_t vxa${M} = wasm_i16x8_load_8x8(a${M});
        a${M} += 8;

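      // Load and sign-extend 8 weights for each of the 4 columns.  LD128 reads
      // two columns per 128-bit load and widens the low/high halves, LD64
      // widens 64-bit loads, and EXTENDED reads weights already stored as int16.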
      $if VARIANT == "LD128":
        $for N in range(0, 4, 2):
          $if N == 0:
            const v128_t vb${N}${N+1} = wasm_v128_load(w);
          $else:
            const v128_t vb${N}${N+1} = wasm_v128_load((const void*) ((uintptr_t) w + ${N * 8} * sizeof(int8_t)));
          const v128_t vxb${N} = wasm_i16x8_widen_low_i8x16(vb${N}${N+1});
          const v128_t vxb${N+1} = wasm_i16x8_widen_high_i8x16(vb${N}${N+1});

          $for M in range(MR):
            const v128_t vprod${M}x${N} = wasm_i16x8_mul(vxb${N}, vxa${M});
            vacc${M}x${N} = wasm_i32x4_add(vacc${M}x${N}, wasm_i32x4_widen_low_i16x8(vprod${M}x${N}));

          $for M in range(MR):
            const v128_t vprod${M}x${N+1} = wasm_i16x8_mul(vxb${N+1}, vxa${M});
            vacc${M}x${N+1} = wasm_i32x4_add(vacc${M}x${N+1}, wasm_i32x4_widen_low_i16x8(vprod${M}x${N+1}));
            vacc${M}x${N} = wasm_i32x4_add(vacc${M}x${N}, wasm_i32x4_widen_high_i16x8(vprod${M}x${N}));

          $for M in range(MR):
            vacc${M}x${N+1} = wasm_i32x4_add(vacc${M}x${N+1}, wasm_i32x4_widen_high_i16x8(vprod${M}x${N+1}));
      $else:
        $for N in range(4):
          $if VARIANT == "LD64":
            $if N == 0:
              const v128_t vxb${N} = wasm_i16x8_load_8x8(w);
            $else:
              const v128_t vxb${N} = wasm_i16x8_load_8x8((const void*) ((uintptr_t) w + ${N * 8} * sizeof(int8_t)));
          $elif VARIANT == "EXTENDED":
            $if N == 0:
              const v128_t vxb${N} = wasm_v128_load(w);
            $else:
              const v128_t vxb${N} = wasm_v128_load((const void*) ((uintptr_t) w + ${N * 8} * sizeof(int16_t)));

          $for M in range(MR):
            const v128_t vprod${M}x${N} = wasm_i16x8_mul(vxa${M}, vxb${N});
            vacc${M}x${N} = wasm_i32x4_add(vacc${M}x${N}, wasm_i32x4_widen_low_i16x8(vprod${M}x${N}));
            vacc${M}x${N} = wasm_i32x4_add(vacc${M}x${N}, wasm_i32x4_widen_high_i16x8(vprod${M}x${N}));

      $if VARIANT == "EXTENDED":
        w = (const void*) ((uintptr_t) w + 32 * sizeof(int16_t));
      $else:
        w = (const void*) ((uintptr_t) w + 32 * sizeof(int8_t));
      k += 8 * sizeof(int8_t);
    }

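    // Each per-column accumulator holds four partial sums for that output
    // column; reduce them with shuffles and adds so every row ends up with a
    // single v128 holding one 32-bit sum per output column.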
    $for M in range(MR):
      const v128_t vacc${M}x02 = wasm_i32x4_add(wasm_v32x4_shuffle(vacc${M}x0, vacc${M}x2, 0, 4, 1, 5), wasm_v32x4_shuffle(vacc${M}x0, vacc${M}x2, 2, 6, 3, 7));
      const v128_t vacc${M}x13 = wasm_i32x4_add(wasm_v32x4_shuffle(vacc${M}x1, vacc${M}x3, 0, 4, 1, 5), wasm_v32x4_shuffle(vacc${M}x1, vacc${M}x3, 2, 6, 3, 7));

    $for M in range(MR):
      v128_t vacc${M}x0123 = wasm_i32x4_add(wasm_v32x4_shuffle(vacc${M}x02, vacc${M}x13, 0, 4, 1, 5), wasm_v32x4_shuffle(vacc${M}x02, vacc${M}x13, 2, 6, 3, 7));

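    // Requantization, part 1: sign-extend each 32-bit sum to 64 bits by
    // interleaving it with its sign mask, multiply by the 64-bit multiplier,
    // add the rounding term, and gather the upper halves of the products into
    // the Q31-scaled result.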
    $for M in range(MR):
      const v128_t vsign${M}x0123 = wasm_i32x4_lt(vacc${M}x0123, vzero);

    $for M in range(MR):
      const v128_t vacc${M}x01 = wasm_v32x4_shuffle(vacc${M}x0123, vsign${M}x0123, 0, 4, 1, 5);

    const v128_t vmultiplier = wasm_v128_load(params->wasmsimd.multiplier);
    const v128_t vrounding = wasm_v128_load(params->wasmsimd.rounding);
    $for M in range(MR):
      const v128_t vprod${M}x01 = wasm_i64x2_add(wasm_i64x2_mul(vacc${M}x01, vmultiplier), vrounding);
      const v128_t vacc${M}x23 = wasm_v32x4_shuffle(vacc${M}x0123, vsign${M}x0123, 2, 6, 3, 7);

    $for M in range(MR):
      const v128_t vprod${M}x23 = wasm_i64x2_add(wasm_i64x2_mul(vacc${M}x23, vmultiplier), vrounding);

    $for M in range(MR):
      const v128_t vq31prod${M}x0123 = wasm_v32x4_shuffle(vprod${M}x01, vprod${M}x23, 1, 3, 5, 7);

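    // Requantization, part 2: divide the Q31 product by 2^shift with rounding.
    // The remainder mask/threshold decide when the arithmetic shift (which
    // truncates toward negative infinity) has to be rounded up by one.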
    const v128_t vremainder_mask = wasm_v128_load(params->wasmsimd.remainder_mask);
    $for M in range(MR):
      const v128_t vrem${M}x0123 = wasm_i32x4_add(wasm_v128_and(vq31prod${M}x0123, vremainder_mask), wasm_i32x4_lt(vq31prod${M}x0123, vzero));

    const v128_t vthreshold = wasm_v128_load(params->wasmsimd.remainder_threshold);
    const int32_t vshift = params->wasmsimd.shift;
    $for M in range(MR):
      vacc${M}x0123 = wasm_i32x4_sub(wasm_i32x4_shr(vq31prod${M}x0123, vshift), wasm_i32x4_gt(vrem${M}x0123, vthreshold));

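    // Convert to int8: narrow the 32-bit results to 16 bits with saturation,
    // add the output zero point (saturating), narrow to 8 bits, and clamp to
    // the [output_min, output_max] range.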
    const v128_t voutput_zero_point = wasm_v128_load(params->wasmsimd.output_zero_point);
    $for M in range(0, MR, 2):
      v128_t vacc${M}${min(M+1, MR-1)}x0123 = wasm_i16x8_add_saturate(wasm_i16x8_narrow_i32x4(vacc${M}x0123, vacc${min(M+1, MR-1)}x0123), voutput_zero_point);

    $if MR > 2:
      v128_t vout = wasm_i8x16_narrow_i16x8(vacc0${min(1, MR-1)}x0123, vacc${min(2, MR-1)}${min(3, MR-1)}x0123);
    $else:
      v128_t vout = wasm_i8x16_narrow_i16x8(vacc0${min(1, MR-1)}x0123, vacc0${min(1, MR-1)}x0123);

    const v128_t voutput_min = wasm_v128_load(params->wasmsimd.output_min);
    vout = wasm_i8x16_max(vout, voutput_min);

    const v128_t voutput_max = wasm_v128_load(params->wasmsimd.output_max);
    vout = wasm_i8x16_min(vout, voutput_max);

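    // Store the tile.  Each row's 4 int8 outputs sit in one 32-bit lane of
    // vout, so the main path writes them with a single 32-bit store per row
    // (via a float lane extract); the remainder path writes 2 and/or 1 bytes.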
    if (nc >= 4) {
      $for M in range(MR):
        *((float*) c${M}) = (float) wasm_f32x4_extract_lane(vout, ${M});

      $for M in range(MR):
        c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);

      $for M in range(MR):
        a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);

      nc -= 4;
    } else {
      if (nc & 2) {
        $for M in range(MR):
          *((uint16_t*) c${M}) = (uint16_t) wasm_i16x8_extract_lane(vout, ${M * 2});
          c${M} += 2;
        vout = wasm_u32x4_shr(vout, 16);
      }
      if (nc & 1) {
        $for M in range(MR):
          *c${M} = (int8_t) wasm_i8x16_extract_lane(vout, ${M * 4});
      }

      nc = 0;
    }
  } while (nc != 0);
}