// Auto-generated file. Do not edit!
//   Template: src/f32-spmm/neon-pipelined.c.in
//   Generator: tools/xngen
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/spmm.h>


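// Sparse matrix-dense matrix multiplication (SpMM) microkernel: for each output
// channel the accumulators start from the first packed weight value (the
// per-channel bias), accumulate the nonzero weights against gathered input rows,
// clamp the result to [min, max], and write 32 MC elements per pass using NEON
// with software-pipelined loads.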
void xnn_f32_spmm_minmax_ukernel_32x1__neon_pipelined(
    size_t mc,
    size_t nc,
    const float*restrict input,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict output,
    size_t output_stride,
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mc != 0);
  assert(mc % sizeof(float) == 0);
  assert(nc != 0);

  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
  size_t output_decrement = output_stride * nc - 32 * sizeof(float);
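  // Main loop: process 32 elements of the MC dimension per iteration (mc is in bytes).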
  while XNN_LIKELY(mc >= 32 * sizeof(float)) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
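    // Software pipelining: preload the first weight, input-offset delta, and input
    // block so that loads for the next step overlap the multiply-accumulates.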
    float32x4_t vw = vld1q_dup_f32(w); w += 1;
    intptr_t diff = *dmap++;
    float32x4_t vi0123 = vld1q_f32(input);
    float32x4_t vi4567 = vld1q_f32(input + 4);
    float32x4_t vi89AB = vld1q_f32(input + 8);
    float32x4_t viCDEF = vld1q_f32(input + 12);
    float32x4_t viGHIJ = vld1q_f32(input + 16);
    float32x4_t viKLMN = vld1q_f32(input + 20);
    float32x4_t viOPQR = vld1q_f32(input + 24);
    float32x4_t viSTUV = vld1q_f32(input + 28);
    size_t n = nc;
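    // Loop over output channels; each channel's accumulators are initialized from
    // the first packed weight value (the bias), already loaded in vw.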
    do {
      uint32_t nnz = *nnzmap++;
      float32x4_t vacc0123 = vw;
      float32x4_t vacc4567 = vw;
      float32x4_t vacc89AB = vw;
      float32x4_t vaccCDEF = vw;
      float32x4_t vaccGHIJ = vw;
      float32x4_t vaccKLMN = vw;
      float32x4_t vaccOPQR = vw;
      float32x4_t vaccSTUV = vw;
      vw = vld1q_dup_f32(w); w += 1;
      if XNN_LIKELY(nnz != 0) {
        do {
          vacc0123 = vmlaq_f32(vacc0123, vi0123, vw);
          vacc4567 = vmlaq_f32(vacc4567, vi4567, vw);
          vacc89AB = vmlaq_f32(vacc89AB, vi89AB, vw);
          vaccCDEF = vmlaq_f32(vaccCDEF, viCDEF, vw);
          vaccGHIJ = vmlaq_f32(vaccGHIJ, viGHIJ, vw);
          vaccKLMN = vmlaq_f32(vaccKLMN, viKLMN, vw);
          vaccOPQR = vmlaq_f32(vaccOPQR, viOPQR, vw);
          vaccSTUV = vmlaq_f32(vaccSTUV, viSTUV, vw);
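          // widx_dmap holds byte offsets between consecutive nonzero-weight input rows;
          // advance the input pointer, then load the next weight and input block.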
          input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
          __builtin_prefetch(input + 16);
          __builtin_prefetch(input + 32);
          diff = *dmap++;
          vw = vld1q_dup_f32(w); w += 1;
          __builtin_prefetch(w + 32);
          vi0123 = vld1q_f32(input);
          vi4567 = vld1q_f32(input + 4);
          vi89AB = vld1q_f32(input + 8);
          viCDEF = vld1q_f32(input + 12);
          viGHIJ = vld1q_f32(input + 16);
          viKLMN = vld1q_f32(input + 20);
          viOPQR = vld1q_f32(input + 24);
          viSTUV = vld1q_f32(input + 28);
        } while (--nnz != 0);
      }
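      // Clamp the accumulators to [min, max] and store one column of 32 outputs.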
      float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
      float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
      float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);
      float32x4_t voutCDEF = vminq_f32(vaccCDEF, vmax);
      float32x4_t voutGHIJ = vminq_f32(vaccGHIJ, vmax);
      float32x4_t voutKLMN = vminq_f32(vaccKLMN, vmax);
      float32x4_t voutOPQR = vminq_f32(vaccOPQR, vmax);
      float32x4_t voutSTUV = vminq_f32(vaccSTUV, vmax);
      vout0123 = vmaxq_f32(vout0123, vmin);
      vout4567 = vmaxq_f32(vout4567, vmin);
      vout89AB = vmaxq_f32(vout89AB, vmin);
      voutCDEF = vmaxq_f32(voutCDEF, vmin);
      voutGHIJ = vmaxq_f32(voutGHIJ, vmin);
      voutKLMN = vmaxq_f32(voutKLMN, vmin);
      voutOPQR = vmaxq_f32(voutOPQR, vmin);
      voutSTUV = vmaxq_f32(voutSTUV, vmin);
      vst1q_f32(output, vout0123);
      vst1q_f32(output + 4, vout4567);
      vst1q_f32(output + 8, vout89AB);
      vst1q_f32(output + 12, voutCDEF);
      vst1q_f32(output + 16, voutGHIJ);
      vst1q_f32(output + 20, voutKLMN);
      vst1q_f32(output + 24, voutOPQR);
      vst1q_f32(output + 28, voutSTUV);
      output = (float*restrict) ((uintptr_t) output + output_stride);
    } while (--n != 0);
    output = (float*restrict) ((uintptr_t) output - output_decrement);
    input += 32;
    mc -= 32 * sizeof(float);
  }
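  // Remainder: handle the leftover MC elements with blocks of 16, 8, 4, 2, and 1.
  // output_decrement grows as the block size shrinks, so output advances by exactly
  // the block size after each column pass.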
  if XNN_UNLIKELY(mc != 0) {
    output_decrement += 16 * sizeof(float);
    if (mc & (16 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      do {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567 = vacc0123;
        float32x4_t vacc89AB = vacc0123;
        float32x4_t vaccCDEF = vacc0123;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t vi0123 = vld1q_f32(input);
            const float32x4_t vi4567 = vld1q_f32(input + 4);
            const float32x4_t vi89AB = vld1q_f32(input + 8);
            const float32x4_t viCDEF = vld1q_f32(input + 12);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            __builtin_prefetch(input + 16);
            __builtin_prefetch(input + 32);
            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
            __builtin_prefetch(w + 32);
            vacc0123 = vmlaq_f32(vacc0123, vi0123, vb);
            vacc4567 = vmlaq_f32(vacc4567, vi4567, vb);
            vacc89AB = vmlaq_f32(vacc89AB, vi89AB, vb);
            vaccCDEF = vmlaq_f32(vaccCDEF, viCDEF, vb);
          } while (--nnz != 0);
        }
        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
        float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
        float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);
        float32x4_t voutCDEF = vminq_f32(vaccCDEF, vmax);
        vout0123 = vmaxq_f32(vout0123, vmin);
        vout4567 = vmaxq_f32(vout4567, vmin);
        vout89AB = vmaxq_f32(vout89AB, vmin);
        voutCDEF = vmaxq_f32(voutCDEF, vmin);
        vst1q_f32(output, vout0123);
        vst1q_f32(output + 4, vout4567);
        vst1q_f32(output + 8, vout89AB);
        vst1q_f32(output + 12, voutCDEF);
        output = (float*restrict) ((uintptr_t) output + output_stride);
      } while (--n != 0);
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 16;
    }
    output_decrement += 8 * sizeof(float);
    if (mc & (8 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      do {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567 = vacc0123;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t vi0123 = vld1q_f32(input);
            const float32x4_t vi4567 = vld1q_f32(input + 4);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            __builtin_prefetch(input + 16);
            __builtin_prefetch(input + 32);
            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
            __builtin_prefetch(w + 32);
            vacc0123 = vmlaq_f32(vacc0123, vi0123, vb);
            vacc4567 = vmlaq_f32(vacc4567, vi4567, vb);
          } while (--nnz != 0);
        }
        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
        float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
        vout0123 = vmaxq_f32(vout0123, vmin);
        vout4567 = vmaxq_f32(vout4567, vmin);
        vst1q_f32(output, vout0123);
        vst1q_f32(output + 4, vout4567);
        output = (float*restrict) ((uintptr_t) output + output_stride);
      } while (--n != 0);
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 8;
    }
    output_decrement += 4 * sizeof(float);
    if (mc & (4 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      do {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t vi0123 = vld1q_f32(input);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            __builtin_prefetch(input + 16);
            __builtin_prefetch(input + 32);
            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
            __builtin_prefetch(w + 32);
            vacc0123 = vmlaq_f32(vacc0123, vi0123, vb);
          } while (--nnz != 0);
        }
        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
        vout0123 = vmaxq_f32(vout0123, vmin);
        vst1q_f32(output, vout0123);
        output = (float*restrict) ((uintptr_t) output + output_stride);
      } while (--n != 0);
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 4;
    }
    output_decrement += 2 * sizeof(float);
    if (mc & (2 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      do {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t vi01 = vld1_f32(input);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            __builtin_prefetch(input + 16);
            __builtin_prefetch(input + 32);
            const float32x2_t vb = vld1_dup_f32(w); w += 1;
            __builtin_prefetch(w + 32);
            vacc01 = vmla_f32(vacc01, vi01, vb);
          } while (--nnz != 0);
        }
        float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
        vout01 = vmax_f32(vout01, vget_low_f32(vmin));
        vst1_f32(output, vout01);
        output = (float*restrict) ((uintptr_t) output + output_stride);
      } while (--n != 0);
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 2;
    }
    output_decrement += 1 * sizeof(float);
    if (mc & (1 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      do {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t vi0 = vld1_dup_f32(input);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            __builtin_prefetch(input + 16);
            __builtin_prefetch(input + 32);
            const float32x2_t vb = vld1_dup_f32(w); w += 1;
            __builtin_prefetch(w + 32);
            vacc0 = vmla_f32(vacc0, vi0, vb);
          } while (--nnz != 0);
        }
        float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
        vout0 = vmax_f32(vout0, vget_low_f32(vmin));
        vst1_lane_f32(output, vout0, 0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
      } while (--n != 0);
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 1;
    }
  }
}