// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert MR % 4 == 0
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
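// Note: lines beginning with $ are template directives (expanded by
// XNNPACK's xngen generator); MR, NR, and UNROLL are template parameters.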
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/spmm.h>


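// SpMM micro-kernel: multiplies a sparse weight matrix by a dense input
// matrix, ${MR} elements of the dense (M) dimension at a time. For each of
// the nc output channels, nidx_nnzmap gives the non-zero weight count and
// widx_dmap gives byte offsets from one used input block to the next; mc is
// likewise measured in bytes.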
void xnn_f32_spmm_minmax_ukernel_${MR}x${NR}__sse${"_x" + str(UNROLL) if UNROLL > 1 else ""}(
    size_t mc,
    size_t nc,
    const float*restrict input,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict output,
    size_t output_stride,
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mc != 0);
  assert(mc % sizeof(float) == 0);
  assert(nc != 0);

  const __m128 vmin = _mm_load_ps(params->sse.min);
  const __m128 vmax = _mm_load_ps(params->sse.max);
  size_t output_decrement = output_stride * nc - ${MR} * sizeof(float);
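  // Main loop: process ${MR} M-dimension elements per iteration.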
  while XNN_LIKELY(mc >= ${MR} * sizeof(float)) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t n = nc;
    do {
      uint32_t nnz = *nnzmap++;
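      // The first packed value for each output channel is its bias (per
      // XNNPACK's SpMM weight layout); broadcast it to seed the accumulators.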
      $if UNROLL > 1:
        __m128 vacc0123x0 = _mm_load1_ps(w);
        w += 1;
        $for K in range(1, UNROLL):
          __m128 vacc0123x${K} = _mm_setzero_ps();
        $for M in range(4, MR, 4):
          __m128 vacc${ABC[M:M+4]}x0 = vacc0123x0;
          $for K in range(1, UNROLL):
            __m128 vacc${ABC[M:M+4]}x${K} = _mm_setzero_ps();
        for (; nnz >= ${UNROLL}; nnz -= ${UNROLL}) {
          $for K in range(UNROLL):
            const intptr_t diff${K} = dmap[${K}];
          dmap += ${UNROLL};
          $for K in range(UNROLL):
            const __m128 vi0123x${K} = _mm_loadu_ps(input);
            $for M in range(4, MR, 4):
              const __m128 vi${ABC[M:M+4]}x${K} = _mm_loadu_ps(input + ${M});
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff${K});
            const __m128 vw${K} = _mm_load1_ps(w);
            w += 1;
            $for M in range(0, MR, 4):
              vacc${ABC[M:M+4]}x${K} = _mm_add_ps(vacc${ABC[M:M+4]}x${K}, _mm_mul_ps(vi${ABC[M:M+4]}x${K}, vw${K}));
        }
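        // Reduce the ${UNROLL} partial accumulators into a single sum.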
        $for M in range(0, MR, 4):
          __m128 vacc${ABC[M:M+4]} = vacc${ABC[M:M+4]}x0;
        $for K in range(1, UNROLL):
          $for M in range(0, MR, 4):
            vacc${ABC[M:M+4]} = _mm_add_ps(vacc${ABC[M:M+4]}, vacc${ABC[M:M+4]}x${K});
      $else:
        __m128 vacc0123 = _mm_load1_ps(w); w += 1;
        $for M in range(4, MR, 4):
          __m128 vacc${ABC[M:M+4]} = vacc0123;
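      // Accumulate the remaining non-zeros (all of them when not unrolled)
      // one weight at a time.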
      if XNN_LIKELY(nnz != 0) {
        do {
          const intptr_t diff = *dmap++;
          const __m128 vi0123 = _mm_loadu_ps(input);
          $for M in range(4, MR, 4):
            const __m128 vi${ABC[M:M+4]} = _mm_loadu_ps(input + ${M});
          input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
          const __m128 vw = _mm_load1_ps(w); w += 1;
          $for M in range(0, MR, 4):
            vacc${ABC[M:M+4]} = _mm_add_ps(vacc${ABC[M:M+4]}, _mm_mul_ps(vi${ABC[M:M+4]}, vw));
        } while (--nnz != 0);
      }
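      // Clamp the accumulators to [min, max] before storing.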
      $for M in range(0, MR, 4):
        __m128 vout${ABC[M:M+4]} = _mm_min_ps(vacc${ABC[M:M+4]}, vmax);
      $for M in range(0, MR, 4):
        vout${ABC[M:M+4]} = _mm_max_ps(vout${ABC[M:M+4]}, vmin);
      _mm_storeu_ps(output, vout0123);
      $for M in range(4, MR, 4):
        _mm_storeu_ps(output + ${M}, vout${ABC[M:M+4]});
      output = (float*restrict) ((uintptr_t) output + output_stride);
    } while (--n != 0);
    output = (float*restrict) ((uintptr_t) output - output_decrement);
    input += ${MR};
    mc -= ${MR} * sizeof(float);
  }
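  // Tail: handle any mc remainder with progressively smaller power-of-2
  // blocks.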
  if XNN_UNLIKELY(mc != 0) {
    $for LOG2M in reversed(range((MR - 1).bit_length())):
      $SUBMR = 1 << LOG2M
      $if SUBMR * 2 >= MR:
        output_decrement += ${MR - SUBMR} * sizeof(float);
      $else:
        output_decrement += ${SUBMR} * sizeof(float);
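      // Process a ${SUBMR}-element block if the remainder includes one.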
      if (mc & (${SUBMR} * sizeof(float))) {
        const float*restrict w = weights;
        const int32_t* dmap = widx_dmap;
        const uint32_t* nnzmap = nidx_nnzmap;
        size_t n = nc;
        do {
          uint32_t nnz = *nnzmap++;
          $if SUBMR == 1:
            __m128 vacc0 = _mm_load_ss(w); w += 1;
          $elif SUBMR == 2:
            __m128 vacc01 = _mm_load_ss(w); w += 1;
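            // unpacklo duplicates the scalar bias into lanes 0 and 1.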
            vacc01 = _mm_unpacklo_ps(vacc01, vacc01);
          $else:
            __m128 vacc0123 = _mm_load1_ps(w); w += 1;
          $for M in range(4, SUBMR, 4):
            __m128 vacc${ABC[M:M+4]} = vacc0123;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              $if SUBMR >= 4:
                const __m128 vi0123 = _mm_loadu_ps(input);
              $elif SUBMR == 2:
                const __m128 vi01 = _mm_loadl_pi(_mm_undefined_ps(), (const __m64*) input);
              $elif SUBMR == 1:
                const __m128 vi0 = _mm_load_ss(input);
              $for M in range(4, SUBMR, 4):
                const __m128 vi${ABC[M:M+4]} = _mm_loadu_ps(input + ${M});
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              $if SUBMR >= 4:
                const __m128 vw = _mm_load1_ps(w); w += 1;
              $elif SUBMR == 2:
                __m128 vw = _mm_load_ss(w); w += 1;
                vw = _mm_unpacklo_ps(vw, vw);
              $else:
                const __m128 vw = _mm_load_ss(w); w += 1;
              $if SUBMR == 1:
                vacc${ABC[0]} = _mm_add_ss(vacc${ABC[0]}, _mm_mul_ss(vi${ABC[0]}, vw));
              $else:
                $for M in range(0, SUBMR, 4):
                  vacc${ABC[M:min(M+4,SUBMR)]} = _mm_add_ps(vacc${ABC[M:min(M+4,SUBMR)]}, _mm_mul_ps(vi${ABC[M:min(M+4,SUBMR)]}, vw));
            } while (--nnz != 0);
          }
          $if SUBMR == 1:
            __m128 vout${ABC[0]} = _mm_min_ss(vacc${ABC[0]}, vmax);
            vout${ABC[0]} = _mm_max_ss(vout${ABC[0]}, vmin);
          $else:
            $for M in range(0, SUBMR, 4):
              __m128 vout${ABC[M:min(M+4,SUBMR)]} = _mm_min_ps(vacc${ABC[M:min(M+4,SUBMR)]}, vmax);
            $for M in range(0, SUBMR, 4):
              vout${ABC[M:min(M+4,SUBMR)]} = _mm_max_ps(vout${ABC[M:min(M+4,SUBMR)]}, vmin);
          $if SUBMR >= 4:
            _mm_storeu_ps(output, vout0123);
          $elif SUBMR == 2:
            _mm_storel_pi((__m64*) output, vout01);
          $elif SUBMR == 1:
            _mm_store_ss(output, vout0);
          $for M in range(4, SUBMR, 4):
            _mm_storeu_ps(output + ${M}, vout${ABC[M:M+4]});
          output = (float*restrict) ((uintptr_t) output + output_stride);
        } while (--n != 0);
        output = (float*restrict) ((uintptr_t) output - output_decrement);
        input += ${SUBMR};
      }
  }
}