1 // Auto-generated file. Do not edit!
2 //   Template: src/f32-spmm/sse.c.in
3 //   Generator: tools/xngen
4 //
5 // Copyright 2019 Google LLC
6 //
7 // This source code is licensed under the BSD-style license found in the
8 // LICENSE file in the root directory of this source tree.
9 
10 #include <assert.h>
11 
12 #include <immintrin.h>
13 
14 #include <xnnpack/spmm.h>
15 
16 
xnn_f32_spmm_minmax_ukernel_32x1__sse(size_t mc,size_t nc,const float * restrict input,const float * restrict weights,const int32_t * restrict widx_dmap,const uint32_t * restrict nidx_nnzmap,float * restrict output,size_t output_stride,const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])17 void xnn_f32_spmm_minmax_ukernel_32x1__sse(
18     size_t mc,
19     size_t nc,
20     const float*restrict input,
21     const float*restrict weights,
22     const int32_t*restrict widx_dmap,
23     const uint32_t*restrict nidx_nnzmap,
24     float*restrict output,
25     size_t output_stride,
26     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
27 {
28   assert(mc != 0);
29   assert(mc % sizeof(float) == 0);
30   assert(nc != 0);
31 
32   const __m128 vmin = _mm_load_ps(params->sse.min);
33   const __m128 vmax = _mm_load_ps(params->sse.max);
34   size_t output_decrement = output_stride * nc - 32 * sizeof(float);
35   while XNN_LIKELY(mc >= 32 * sizeof(float)) {
36     const float*restrict w = weights;
37     const int32_t* dmap = widx_dmap;
38     const uint32_t* nnzmap = nidx_nnzmap;
39     size_t n = nc;
40     do {
41       uint32_t nnz = *nnzmap++;
42       __m128 vacc0123 = _mm_load1_ps(w); w += 1;
43       __m128 vacc4567 = vacc0123;
44       __m128 vacc89AB = vacc0123;
45       __m128 vaccCDEF = vacc0123;
46       __m128 vaccGHIJ = vacc0123;
47       __m128 vaccKLMN = vacc0123;
48       __m128 vaccOPQR = vacc0123;
49       __m128 vaccSTUV = vacc0123;
50       if XNN_LIKELY(nnz != 0) {
51         do {
52           const intptr_t diff = *dmap++;
53           const __m128 vi0123 = _mm_loadu_ps(input);
54           const __m128 vi4567 = _mm_loadu_ps(input + 4);
55           const __m128 vi89AB = _mm_loadu_ps(input + 8);
56           const __m128 viCDEF = _mm_loadu_ps(input + 12);
57           const __m128 viGHIJ = _mm_loadu_ps(input + 16);
58           const __m128 viKLMN = _mm_loadu_ps(input + 20);
59           const __m128 viOPQR = _mm_loadu_ps(input + 24);
60           const __m128 viSTUV = _mm_loadu_ps(input + 28);
61           input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
62           const __m128 vw = _mm_load1_ps(w); w += 1;
63           vacc0123 = _mm_add_ps(vacc0123, _mm_mul_ps(vi0123, vw));
64           vacc4567 = _mm_add_ps(vacc4567, _mm_mul_ps(vi4567, vw));
65           vacc89AB = _mm_add_ps(vacc89AB, _mm_mul_ps(vi89AB, vw));
66           vaccCDEF = _mm_add_ps(vaccCDEF, _mm_mul_ps(viCDEF, vw));
67           vaccGHIJ = _mm_add_ps(vaccGHIJ, _mm_mul_ps(viGHIJ, vw));
68           vaccKLMN = _mm_add_ps(vaccKLMN, _mm_mul_ps(viKLMN, vw));
69           vaccOPQR = _mm_add_ps(vaccOPQR, _mm_mul_ps(viOPQR, vw));
70           vaccSTUV = _mm_add_ps(vaccSTUV, _mm_mul_ps(viSTUV, vw));
71         } while (--nnz != 0);
72       }
73       __m128 vout0123 = _mm_min_ps(vacc0123, vmax);
74       __m128 vout4567 = _mm_min_ps(vacc4567, vmax);
75       __m128 vout89AB = _mm_min_ps(vacc89AB, vmax);
76       __m128 voutCDEF = _mm_min_ps(vaccCDEF, vmax);
77       __m128 voutGHIJ = _mm_min_ps(vaccGHIJ, vmax);
78       __m128 voutKLMN = _mm_min_ps(vaccKLMN, vmax);
79       __m128 voutOPQR = _mm_min_ps(vaccOPQR, vmax);
80       __m128 voutSTUV = _mm_min_ps(vaccSTUV, vmax);
81       vout0123 = _mm_max_ps(vout0123, vmin);
82       vout4567 = _mm_max_ps(vout4567, vmin);
83       vout89AB = _mm_max_ps(vout89AB, vmin);
84       voutCDEF = _mm_max_ps(voutCDEF, vmin);
85       voutGHIJ = _mm_max_ps(voutGHIJ, vmin);
86       voutKLMN = _mm_max_ps(voutKLMN, vmin);
87       voutOPQR = _mm_max_ps(voutOPQR, vmin);
88       voutSTUV = _mm_max_ps(voutSTUV, vmin);
89       _mm_storeu_ps(output, vout0123);
90       _mm_storeu_ps(output + 4, vout4567);
91       _mm_storeu_ps(output + 8, vout89AB);
92       _mm_storeu_ps(output + 12, voutCDEF);
93       _mm_storeu_ps(output + 16, voutGHIJ);
94       _mm_storeu_ps(output + 20, voutKLMN);
95       _mm_storeu_ps(output + 24, voutOPQR);
96       _mm_storeu_ps(output + 28, voutSTUV);
97       output = (float*restrict) ((uintptr_t) output + output_stride);
98     } while (--n != 0);
99     output = (float*restrict) ((uintptr_t) output - output_decrement);
100     input += 32;
101     mc -= 32 * sizeof(float);
102   }
103   if XNN_UNLIKELY(mc != 0) {
104     output_decrement += 16 * sizeof(float);
105     if (mc & (16 * sizeof(float))) {
106       const float*restrict w = weights;
107       const int32_t* dmap = widx_dmap;
108       const uint32_t* nnzmap = nidx_nnzmap;
109       size_t n = nc;
110       do {
111         uint32_t nnz = *nnzmap++;
112         __m128 vacc0123 = _mm_load1_ps(w); w += 1;
113         __m128 vacc4567 = vacc0123;
114         __m128 vacc89AB = vacc0123;
115         __m128 vaccCDEF = vacc0123;
116         if XNN_LIKELY(nnz != 0) {
117           do {
118             const intptr_t diff = *dmap++;
119             const __m128 vi0123 = _mm_loadu_ps(input);
120             const __m128 vi4567 = _mm_loadu_ps(input + 4);
121             const __m128 vi89AB = _mm_loadu_ps(input + 8);
122             const __m128 viCDEF = _mm_loadu_ps(input + 12);
123             input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
124             const __m128 vw = _mm_load1_ps(w); w += 1;
125             vacc0123 = _mm_add_ps(vacc0123, _mm_mul_ps(vi0123, vw));
126             vacc4567 = _mm_add_ps(vacc4567, _mm_mul_ps(vi4567, vw));
127             vacc89AB = _mm_add_ps(vacc89AB, _mm_mul_ps(vi89AB, vw));
128             vaccCDEF = _mm_add_ps(vaccCDEF, _mm_mul_ps(viCDEF, vw));
129           } while (--nnz != 0);
130         }
131         __m128 vout0123 = _mm_min_ps(vacc0123, vmax);
132         __m128 vout4567 = _mm_min_ps(vacc4567, vmax);
133         __m128 vout89AB = _mm_min_ps(vacc89AB, vmax);
134         __m128 voutCDEF = _mm_min_ps(vaccCDEF, vmax);
135         vout0123 = _mm_max_ps(vout0123, vmin);
136         vout4567 = _mm_max_ps(vout4567, vmin);
137         vout89AB = _mm_max_ps(vout89AB, vmin);
138         voutCDEF = _mm_max_ps(voutCDEF, vmin);
139         _mm_storeu_ps(output, vout0123);
140         _mm_storeu_ps(output + 4, vout4567);
141         _mm_storeu_ps(output + 8, vout89AB);
142         _mm_storeu_ps(output + 12, voutCDEF);
143         output = (float*restrict) ((uintptr_t) output + output_stride);
144       } while (--n != 0);
145       output = (float*restrict) ((uintptr_t) output - output_decrement);
146       input += 16;
147     }
148     output_decrement += 8 * sizeof(float);
149     if (mc & (8 * sizeof(float))) {
150       const float*restrict w = weights;
151       const int32_t* dmap = widx_dmap;
152       const uint32_t* nnzmap = nidx_nnzmap;
153       size_t n = nc;
154       do {
155         uint32_t nnz = *nnzmap++;
156         __m128 vacc0123 = _mm_load1_ps(w); w += 1;
157         __m128 vacc4567 = vacc0123;
158         if XNN_LIKELY(nnz != 0) {
159           do {
160             const intptr_t diff = *dmap++;
161             const __m128 vi0123 = _mm_loadu_ps(input);
162             const __m128 vi4567 = _mm_loadu_ps(input + 4);
163             input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
164             const __m128 vw = _mm_load1_ps(w); w += 1;
165             vacc0123 = _mm_add_ps(vacc0123, _mm_mul_ps(vi0123, vw));
166             vacc4567 = _mm_add_ps(vacc4567, _mm_mul_ps(vi4567, vw));
167           } while (--nnz != 0);
168         }
169         __m128 vout0123 = _mm_min_ps(vacc0123, vmax);
170         __m128 vout4567 = _mm_min_ps(vacc4567, vmax);
171         vout0123 = _mm_max_ps(vout0123, vmin);
172         vout4567 = _mm_max_ps(vout4567, vmin);
173         _mm_storeu_ps(output, vout0123);
174         _mm_storeu_ps(output + 4, vout4567);
175         output = (float*restrict) ((uintptr_t) output + output_stride);
176       } while (--n != 0);
177       output = (float*restrict) ((uintptr_t) output - output_decrement);
178       input += 8;
179     }
180     output_decrement += 4 * sizeof(float);
181     if (mc & (4 * sizeof(float))) {
182       const float*restrict w = weights;
183       const int32_t* dmap = widx_dmap;
184       const uint32_t* nnzmap = nidx_nnzmap;
185       size_t n = nc;
186       do {
187         uint32_t nnz = *nnzmap++;
188         __m128 vacc0123 = _mm_load1_ps(w); w += 1;
189         if XNN_LIKELY(nnz != 0) {
190           do {
191             const intptr_t diff = *dmap++;
192             const __m128 vi0123 = _mm_loadu_ps(input);
193             input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
194             const __m128 vw = _mm_load1_ps(w); w += 1;
195             vacc0123 = _mm_add_ps(vacc0123, _mm_mul_ps(vi0123, vw));
196           } while (--nnz != 0);
197         }
198         __m128 vout0123 = _mm_min_ps(vacc0123, vmax);
199         vout0123 = _mm_max_ps(vout0123, vmin);
200         _mm_storeu_ps(output, vout0123);
201         output = (float*restrict) ((uintptr_t) output + output_stride);
202       } while (--n != 0);
203       output = (float*restrict) ((uintptr_t) output - output_decrement);
204       input += 4;
205     }
206     output_decrement += 2 * sizeof(float);
207     if (mc & (2 * sizeof(float))) {
208       const float*restrict w = weights;
209       const int32_t* dmap = widx_dmap;
210       const uint32_t* nnzmap = nidx_nnzmap;
211       size_t n = nc;
212       do {
213         uint32_t nnz = *nnzmap++;
214         __m128 vacc01 = _mm_load_ss(w); w += 1;
215         vacc01 = _mm_unpacklo_ps(vacc01, vacc01);
216         if XNN_LIKELY(nnz != 0) {
217           do {
218             const intptr_t diff = *dmap++;
219             const __m128 vi01 = _mm_loadl_pi(_mm_undefined_ps(), (const __m64*) input);
220             input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
221             __m128 vw = _mm_load_ss(w); w += 1;
222             vw = _mm_unpacklo_ps(vw, vw);
223             vacc01 = _mm_add_ps(vacc01, _mm_mul_ps(vi01, vw));
224           } while (--nnz != 0);
225         }
226         __m128 vout01 = _mm_min_ps(vacc01, vmax);
227         vout01 = _mm_max_ps(vout01, vmin);
228         _mm_storel_pi((__m64*) output, vout01);
229         output = (float*restrict) ((uintptr_t) output + output_stride);
230       } while (--n != 0);
231       output = (float*restrict) ((uintptr_t) output - output_decrement);
232       input += 2;
233     }
234     output_decrement += 1 * sizeof(float);
235     if (mc & (1 * sizeof(float))) {
236       const float*restrict w = weights;
237       const int32_t* dmap = widx_dmap;
238       const uint32_t* nnzmap = nidx_nnzmap;
239       size_t n = nc;
240       do {
241         uint32_t nnz = *nnzmap++;
242         __m128 vacc0 = _mm_load_ss(w); w += 1;
243         if XNN_LIKELY(nnz != 0) {
244           do {
245             const intptr_t diff = *dmap++;
246             const __m128 vi0 = _mm_load_ss(input);
247             input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
248             const __m128 vw = _mm_load_ss(w); w += 1;
249             vacc0 = _mm_add_ss(vacc0, _mm_mul_ss(vi0, vw));
250           } while (--nnz != 0);
251         }
252         __m128 vout0 = _mm_min_ss(vacc0, vmax);
253         vout0 = _mm_max_ss(vout0, vmin);
254         _mm_store_ss(output, vout0);
255         output = (float*restrict) ((uintptr_t) output + output_stride);
256       } while (--n != 0);
257       output = (float*restrict) ((uintptr_t) output - output_decrement);
258       input += 1;
259     }
260   }
261 }
262