// Auto-generated file. Do not edit!
//   Template: src/f32-spmm/sse.c.in
//   Generator: tools/xngen
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <immintrin.h>

#include <xnnpack/spmm.h>

xnn_f32_spmm_minmax_ukernel_32x1__sse(size_t mc,size_t nc,const float * restrict input,const float * restrict weights,const int32_t * restrict widx_dmap,const uint32_t * restrict nidx_nnzmap,float * restrict output,size_t output_stride,const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])17 void xnn_f32_spmm_minmax_ukernel_32x1__sse(
18 size_t mc,
19 size_t nc,
20 const float*restrict input,
21 const float*restrict weights,
22 const int32_t*restrict widx_dmap,
23 const uint32_t*restrict nidx_nnzmap,
24 float*restrict output,
25 size_t output_stride,
26 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
27 {
28 assert(mc != 0);
29 assert(mc % sizeof(float) == 0);
30 assert(nc != 0);
31
32 const __m128 vmin = _mm_load_ps(params->sse.min);
33 const __m128 vmax = _mm_load_ps(params->sse.max);
34 size_t output_decrement = output_stride * nc - 32 * sizeof(float);
35 while XNN_LIKELY(mc >= 32 * sizeof(float)) {
36 const float*restrict w = weights;
37 const int32_t* dmap = widx_dmap;
38 const uint32_t* nnzmap = nidx_nnzmap;
39 size_t n = nc;
40 do {
41 uint32_t nnz = *nnzmap++;
42 __m128 vacc0123 = _mm_load1_ps(w); w += 1;
43 __m128 vacc4567 = vacc0123;
44 __m128 vacc89AB = vacc0123;
45 __m128 vaccCDEF = vacc0123;
46 __m128 vaccGHIJ = vacc0123;
47 __m128 vaccKLMN = vacc0123;
48 __m128 vaccOPQR = vacc0123;
49 __m128 vaccSTUV = vacc0123;
50 if XNN_LIKELY(nnz != 0) {
51 do {
52 const intptr_t diff = *dmap++;
53 const __m128 vi0123 = _mm_loadu_ps(input);
54 const __m128 vi4567 = _mm_loadu_ps(input + 4);
55 const __m128 vi89AB = _mm_loadu_ps(input + 8);
56 const __m128 viCDEF = _mm_loadu_ps(input + 12);
57 const __m128 viGHIJ = _mm_loadu_ps(input + 16);
58 const __m128 viKLMN = _mm_loadu_ps(input + 20);
59 const __m128 viOPQR = _mm_loadu_ps(input + 24);
60 const __m128 viSTUV = _mm_loadu_ps(input + 28);
61 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
62 const __m128 vw = _mm_load1_ps(w); w += 1;
63 vacc0123 = _mm_add_ps(vacc0123, _mm_mul_ps(vi0123, vw));
64 vacc4567 = _mm_add_ps(vacc4567, _mm_mul_ps(vi4567, vw));
65 vacc89AB = _mm_add_ps(vacc89AB, _mm_mul_ps(vi89AB, vw));
66 vaccCDEF = _mm_add_ps(vaccCDEF, _mm_mul_ps(viCDEF, vw));
67 vaccGHIJ = _mm_add_ps(vaccGHIJ, _mm_mul_ps(viGHIJ, vw));
68 vaccKLMN = _mm_add_ps(vaccKLMN, _mm_mul_ps(viKLMN, vw));
69 vaccOPQR = _mm_add_ps(vaccOPQR, _mm_mul_ps(viOPQR, vw));
70 vaccSTUV = _mm_add_ps(vaccSTUV, _mm_mul_ps(viSTUV, vw));
71 } while (--nnz != 0);
72 }
73 __m128 vout0123 = _mm_min_ps(vacc0123, vmax);
74 __m128 vout4567 = _mm_min_ps(vacc4567, vmax);
75 __m128 vout89AB = _mm_min_ps(vacc89AB, vmax);
76 __m128 voutCDEF = _mm_min_ps(vaccCDEF, vmax);
77 __m128 voutGHIJ = _mm_min_ps(vaccGHIJ, vmax);
78 __m128 voutKLMN = _mm_min_ps(vaccKLMN, vmax);
79 __m128 voutOPQR = _mm_min_ps(vaccOPQR, vmax);
80 __m128 voutSTUV = _mm_min_ps(vaccSTUV, vmax);
81 vout0123 = _mm_max_ps(vout0123, vmin);
82 vout4567 = _mm_max_ps(vout4567, vmin);
83 vout89AB = _mm_max_ps(vout89AB, vmin);
84 voutCDEF = _mm_max_ps(voutCDEF, vmin);
85 voutGHIJ = _mm_max_ps(voutGHIJ, vmin);
86 voutKLMN = _mm_max_ps(voutKLMN, vmin);
87 voutOPQR = _mm_max_ps(voutOPQR, vmin);
88 voutSTUV = _mm_max_ps(voutSTUV, vmin);
89 _mm_storeu_ps(output, vout0123);
90 _mm_storeu_ps(output + 4, vout4567);
91 _mm_storeu_ps(output + 8, vout89AB);
92 _mm_storeu_ps(output + 12, voutCDEF);
93 _mm_storeu_ps(output + 16, voutGHIJ);
94 _mm_storeu_ps(output + 20, voutKLMN);
95 _mm_storeu_ps(output + 24, voutOPQR);
96 _mm_storeu_ps(output + 28, voutSTUV);
97 output = (float*restrict) ((uintptr_t) output + output_stride);
98 } while (--n != 0);
99 output = (float*restrict) ((uintptr_t) output - output_decrement);
100 input += 32;
101 mc -= 32 * sizeof(float);
102 }
103 if XNN_UNLIKELY(mc != 0) {
104 output_decrement += 16 * sizeof(float);
105 if (mc & (16 * sizeof(float))) {
106 const float*restrict w = weights;
107 const int32_t* dmap = widx_dmap;
108 const uint32_t* nnzmap = nidx_nnzmap;
109 size_t n = nc;
110 do {
111 uint32_t nnz = *nnzmap++;
112 __m128 vacc0123 = _mm_load1_ps(w); w += 1;
113 __m128 vacc4567 = vacc0123;
114 __m128 vacc89AB = vacc0123;
115 __m128 vaccCDEF = vacc0123;
116 if XNN_LIKELY(nnz != 0) {
117 do {
118 const intptr_t diff = *dmap++;
119 const __m128 vi0123 = _mm_loadu_ps(input);
120 const __m128 vi4567 = _mm_loadu_ps(input + 4);
121 const __m128 vi89AB = _mm_loadu_ps(input + 8);
122 const __m128 viCDEF = _mm_loadu_ps(input + 12);
123 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
124 const __m128 vw = _mm_load1_ps(w); w += 1;
125 vacc0123 = _mm_add_ps(vacc0123, _mm_mul_ps(vi0123, vw));
126 vacc4567 = _mm_add_ps(vacc4567, _mm_mul_ps(vi4567, vw));
127 vacc89AB = _mm_add_ps(vacc89AB, _mm_mul_ps(vi89AB, vw));
128 vaccCDEF = _mm_add_ps(vaccCDEF, _mm_mul_ps(viCDEF, vw));
129 } while (--nnz != 0);
130 }
131 __m128 vout0123 = _mm_min_ps(vacc0123, vmax);
132 __m128 vout4567 = _mm_min_ps(vacc4567, vmax);
133 __m128 vout89AB = _mm_min_ps(vacc89AB, vmax);
134 __m128 voutCDEF = _mm_min_ps(vaccCDEF, vmax);
135 vout0123 = _mm_max_ps(vout0123, vmin);
136 vout4567 = _mm_max_ps(vout4567, vmin);
137 vout89AB = _mm_max_ps(vout89AB, vmin);
138 voutCDEF = _mm_max_ps(voutCDEF, vmin);
139 _mm_storeu_ps(output, vout0123);
140 _mm_storeu_ps(output + 4, vout4567);
141 _mm_storeu_ps(output + 8, vout89AB);
142 _mm_storeu_ps(output + 12, voutCDEF);
143 output = (float*restrict) ((uintptr_t) output + output_stride);
144 } while (--n != 0);
145 output = (float*restrict) ((uintptr_t) output - output_decrement);
146 input += 16;
147 }
148 output_decrement += 8 * sizeof(float);
149 if (mc & (8 * sizeof(float))) {
150 const float*restrict w = weights;
151 const int32_t* dmap = widx_dmap;
152 const uint32_t* nnzmap = nidx_nnzmap;
153 size_t n = nc;
154 do {
155 uint32_t nnz = *nnzmap++;
156 __m128 vacc0123 = _mm_load1_ps(w); w += 1;
157 __m128 vacc4567 = vacc0123;
158 if XNN_LIKELY(nnz != 0) {
159 do {
160 const intptr_t diff = *dmap++;
161 const __m128 vi0123 = _mm_loadu_ps(input);
162 const __m128 vi4567 = _mm_loadu_ps(input + 4);
163 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
164 const __m128 vw = _mm_load1_ps(w); w += 1;
165 vacc0123 = _mm_add_ps(vacc0123, _mm_mul_ps(vi0123, vw));
166 vacc4567 = _mm_add_ps(vacc4567, _mm_mul_ps(vi4567, vw));
167 } while (--nnz != 0);
168 }
169 __m128 vout0123 = _mm_min_ps(vacc0123, vmax);
170 __m128 vout4567 = _mm_min_ps(vacc4567, vmax);
171 vout0123 = _mm_max_ps(vout0123, vmin);
172 vout4567 = _mm_max_ps(vout4567, vmin);
173 _mm_storeu_ps(output, vout0123);
174 _mm_storeu_ps(output + 4, vout4567);
175 output = (float*restrict) ((uintptr_t) output + output_stride);
176 } while (--n != 0);
177 output = (float*restrict) ((uintptr_t) output - output_decrement);
178 input += 8;
179 }
180 output_decrement += 4 * sizeof(float);
181 if (mc & (4 * sizeof(float))) {
182 const float*restrict w = weights;
183 const int32_t* dmap = widx_dmap;
184 const uint32_t* nnzmap = nidx_nnzmap;
185 size_t n = nc;
186 do {
187 uint32_t nnz = *nnzmap++;
188 __m128 vacc0123 = _mm_load1_ps(w); w += 1;
189 if XNN_LIKELY(nnz != 0) {
190 do {
191 const intptr_t diff = *dmap++;
192 const __m128 vi0123 = _mm_loadu_ps(input);
193 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
194 const __m128 vw = _mm_load1_ps(w); w += 1;
195 vacc0123 = _mm_add_ps(vacc0123, _mm_mul_ps(vi0123, vw));
196 } while (--nnz != 0);
197 }
198 __m128 vout0123 = _mm_min_ps(vacc0123, vmax);
199 vout0123 = _mm_max_ps(vout0123, vmin);
200 _mm_storeu_ps(output, vout0123);
201 output = (float*restrict) ((uintptr_t) output + output_stride);
202 } while (--n != 0);
203 output = (float*restrict) ((uintptr_t) output - output_decrement);
204 input += 4;
205 }
206 output_decrement += 2 * sizeof(float);
207 if (mc & (2 * sizeof(float))) {
208 const float*restrict w = weights;
209 const int32_t* dmap = widx_dmap;
210 const uint32_t* nnzmap = nidx_nnzmap;
211 size_t n = nc;
212 do {
213 uint32_t nnz = *nnzmap++;
214 __m128 vacc01 = _mm_load_ss(w); w += 1;
215 vacc01 = _mm_unpacklo_ps(vacc01, vacc01);
216 if XNN_LIKELY(nnz != 0) {
217 do {
218 const intptr_t diff = *dmap++;
219 const __m128 vi01 = _mm_loadl_pi(_mm_undefined_ps(), (const __m64*) input);
220 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
221 __m128 vw = _mm_load_ss(w); w += 1;
222 vw = _mm_unpacklo_ps(vw, vw);
223 vacc01 = _mm_add_ps(vacc01, _mm_mul_ps(vi01, vw));
224 } while (--nnz != 0);
225 }
226 __m128 vout01 = _mm_min_ps(vacc01, vmax);
227 vout01 = _mm_max_ps(vout01, vmin);
228 _mm_storel_pi((__m64*) output, vout01);
229 output = (float*restrict) ((uintptr_t) output + output_stride);
230 } while (--n != 0);
231 output = (float*restrict) ((uintptr_t) output - output_decrement);
232 input += 2;
233 }
234 output_decrement += 1 * sizeof(float);
235 if (mc & (1 * sizeof(float))) {
236 const float*restrict w = weights;
237 const int32_t* dmap = widx_dmap;
238 const uint32_t* nnzmap = nidx_nnzmap;
239 size_t n = nc;
240 do {
241 uint32_t nnz = *nnzmap++;
242 __m128 vacc0 = _mm_load_ss(w); w += 1;
243 if XNN_LIKELY(nnz != 0) {
244 do {
245 const intptr_t diff = *dmap++;
246 const __m128 vi0 = _mm_load_ss(input);
247 input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
248 const __m128 vw = _mm_load_ss(w); w += 1;
249 vacc0 = _mm_add_ss(vacc0, _mm_mul_ss(vi0, vw));
250 } while (--nnz != 0);
251 }
252 __m128 vout0 = _mm_min_ss(vacc0, vmax);
253 vout0 = _mm_max_ss(vout0, vmin);
254 _mm_store_ss(output, vout0);
255 output = (float*restrict) ((uintptr_t) output + output_stride);
256 } while (--n != 0);
257 output = (float*restrict) ((uintptr_t) output - output_decrement);
258 input += 1;
259 }
260 }
261 }
262