// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

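// Template for f32 PPMM microkernels: a GEMM variant that reads the
// left-hand operand from a pre-packed panel. MR and NR are the tile
// dimensions (both multiples of 4); X86 selects the clamping strategy.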
$assert MR % 4 == 0
$assert NR % 4 == 0
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/ppmm.h>


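// Computes an ${MR}x${NR} output tile. `a` points to a pre-packed panel of
// the left-hand matrix (${MR} floats per K step), `w` to the packed weights
// (${NR} floats per K step, preceded by ${NR} values that seed the
// accumulators), and `c` to the output, whose rows are cm_stride bytes apart.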
void xnn_f32_ppmm_minmax_ukernel_${MR}x${NR}__wasmsimd_${"x86" if X86 else "arm"}_splat(
  size_t mr,
  size_t nc,
  size_t kc,
  const float*restrict a,
  const float*restrict w,
  float*restrict c,
  size_t cm_stride,
  size_t cn_stride,
  const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(float) == 0);

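  // Set up one output pointer per row of the tile; when mr < ${MR} the excess
  // pointers alias the previous row, and the duplicate stores are harmlessly
  // overwritten by the valid row (rows are stored in descending order).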
  float* c0 = c;
  $for M in range(1, MR):
    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        c${M} = c${M-1};
      }

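  // On the ARM variant, the clamping bounds are hoisted above the main loop;
  // the x86 variant instead loads them after the accumulation loop, which
  // presumably keeps register pressure lower while accumulating.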
  $if not X86:
    const v128_t vmin = wasm_v32x4_load_splat(&params->scalar.min);
    const v128_t vmax = wasm_v32x4_load_splat(&params->scalar.max);
  do {
    v128_t vacc0x${ABC[0:4]} = wasm_v128_load(w);
    $for N in range(4, NR, 4):
      v128_t vacc0x${ABC[N:N+4]} = wasm_v128_load(w + ${N});
    $for M in range(1, MR):
      $for N in range(0, NR, 4):
        v128_t vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};
    w += ${NR};

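    // Accumulate over the K dimension, one float per row per iteration:
    // each step consumes ${MR} packed A elements and ${NR} weights.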
    size_t k = kc;
    do {
      const v128_t va${ABC[0:4]} = wasm_v128_load(a);
      $for M in range(4, MR, 4):
        const v128_t va${ABC[M:M+4]} = wasm_v128_load(a + ${M});
      a += ${MR};

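      // Load one K step (${NR} floats) of the packed weights.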
      const v128_t vb${ABC[0:4]} = wasm_v128_load(w);
      $for N in range(4, NR, 4):
        const v128_t vb${ABC[N:N+4]} = wasm_v128_load(w + ${N});
      w += ${NR};

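      // Splat each A element across all four lanes of a vector, then
      // multiply-accumulate it against the corresponding weight vectors.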
      $for M in range(MR):
        $MMMM = str(M) * 4
        const v128_t va${MMMM} = wasm_v32x4_shuffle(va${ABC[M&-4:4+M&-4]}, va${ABC[M&-4:4+M&-4]}, ${M % 4}, ${M % 4}, ${M % 4}, ${M % 4});

      $for N in range(0, NR, 4):
        $for M in range(MR):
          $MMMM = str(M) * 4
          vacc${M}x${ABC[N:N+4]} = wasm_f32x4_add(vacc${M}x${ABC[N:N+4]}, wasm_f32x4_mul(va${MMMM}, vb${ABC[N:N+4]}));

      k -= sizeof(float);
    } while (k != 0);

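    // Clamp the accumulators to [min, max]. The x86 variant uses compare +
    // bitselect, which likely lowers to cheaper code than f32x4.min/max does
    // on x86; the ARM variant uses f32x4.min/max directly.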
    $if X86:
      const v128_t vmin = wasm_v32x4_load_splat(&params->scalar.min);
      $for N in range(0, NR, 4):
        $for M in range(MR):
          vacc${M}x${ABC[N:N+4]} = wasm_v128_bitselect(vmin, vacc${M}x${ABC[N:N+4]}, wasm_f32x4_lt(vacc${M}x${ABC[N:N+4]}, vmin));

      const v128_t vmax = wasm_v32x4_load_splat(&params->scalar.max);
      $for N in range(0, NR, 4):
        $for M in range(MR):
          vacc${M}x${ABC[N:N+4]} = wasm_v128_bitselect(vacc${M}x${ABC[N:N+4]}, vmax, wasm_f32x4_le(vacc${M}x${ABC[N:N+4]}, vmax));
    $else:
      $for N in range(0, NR, 4):
        $for M in range(MR):
          vacc${M}x${ABC[N:N+4]} = wasm_f32x4_max(vacc${M}x${ABC[N:N+4]}, vmin);

      $for N in range(0, NR, 4):
        $for M in range(MR):
          vacc${M}x${ABC[N:N+4]} = wasm_f32x4_min(vacc${M}x${ABC[N:N+4]}, vmax);

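    // Main path: store a full ${NR}-wide tile, rewind `a` to the start of the
    // packed panel, and advance the output pointers to the next column tile.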
    if XNN_LIKELY(nc >= ${NR}) {
      $for M in reversed(range(MR)):
        wasm_v128_store(c${M}, vacc${M}x${ABC[0:4]});
        $for N in range(4, NR, 4):
          wasm_v128_store(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});

      a = (const float*) ((uintptr_t) a - kc * ${MR});

      $for M in reversed(range(MR)):
        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);

      nc -= ${NR};
    } else {
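      // Remainder path: decompose nc into powers of two, storing the matching
      // lanes and shifting the surviving lanes down after each store.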
      $for LOG2N in reversed(range(NR.bit_length())):
        $if NR != 1 << LOG2N:
          if (nc & ${1 << LOG2N}) {
            $if LOG2N >= 2:
              $for M in reversed(range(MR)):
                wasm_v128_store(c${M}, vacc${M}x${ABC[0:4]});
                $for N in range(4, 1 << LOG2N, 4):
                  wasm_v128_store(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});

              $for M in reversed(range(MR)):
                $for N in range(0, 1 << (LOG2N - 1), 4):
                  vacc${M}x${ABC[N:N+4]} = vacc${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+4]};

              $for M in reversed(range(MR)):
                c${M} += ${1 << LOG2N};
            $elif LOG2N == 1:
              $for M in reversed(range(MR)):
                *((double*) c${M}) = wasm_f64x2_extract_lane(vacc${M}x${ABC[0:4]}, 0);

              $for M in reversed(range(MR)):
                vacc${M}x${ABC[0:4]} = wasm_v32x4_shuffle(vacc${M}x${ABC[0:4]}, vacc${M}x${ABC[0:4]}, 2, 3, 2, 3);

              $for M in reversed(range(MR)):
                c${M} += 2;
            $elif LOG2N == 0:
              $for M in reversed(range(MR)):
                *c${M} = wasm_f32x4_extract_lane(vacc${M}x${ABC[0:4]}, 0);
          }

      nc = 0;
    }
  } while (nc != 0);
}