1// Copyright 2019 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert MR % 4 == 0
7$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
8$VMULADD_F32 = "vfma_f32" if FMA else "vmla_f32"
9$VMULADDQ_F32 = "vfmaq_f32" if FMA else "vmlaq_f32"
10#include <assert.h>
11
12#include <arm_neon.h>
13
14#include <xnnpack/spmm.h>
15
16
// SpMM micro-kernel: multiplies a sparse weights matrix by a dense input
// matrix and clamps the result to [min, max].  NEON, with FMA variants when
// the template is expanded with FMA=1, and optional inner-loop unrolling.
//
// Template parameters (expanded by the generator):
//   MR     - rows of input/output processed per tile; asserted to be a multiple of 4
//   NR     - part of the generated function name (tile width in the naming scheme)
//   UNROLL - number of independent accumulator sets in the main non-zero loop
//
// Runtime arguments:
//   mc          - size of the row dimension, in BYTES (asserted non-zero and a
//                 multiple of sizeof(float))
//   nc          - number of output channels; one pass of the inner do/while each
//   input       - dense input; advanced by signed byte deltas read from widx_dmap
//   weights     - packed stream of per-channel values: one leading scalar that
//                 seeds the accumulators (presumably the bias -- confirm against
//                 the weight-packing code), followed by that channel's non-zero
//                 weight values
//   widx_dmap   - one signed BYTE increment per non-zero, applied to `input`
//   nidx_nnzmap - per-output-channel count of non-zero weights
//   output      - written ${MR} floats per channel; consecutive channels are
//                 output_stride bytes apart
//   params      - scalar min/max bounds used for output clamping
void xnn_f32_spmm_minmax_ukernel_${MR}x${NR}__${"neonfma" if FMA else "neon"}${"_x" + str(UNROLL) if UNROLL > 1 else ""}(
    size_t mc,
    size_t nc,
    const float*restrict input,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict output,
    size_t output_stride,
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mc != 0);
  assert(mc % sizeof(float) == 0);
  assert(nc != 0);

  // Broadcast the clamping bounds once, outside all loops.
  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
  // After writing nc channels the output pointer sits nc strides ahead; this
  // rewinds it back to the start of the NEXT group of ${MR} rows.
  size_t output_decrement = output_stride * nc - ${MR} * sizeof(float);
  // Main loop: full tiles of ${MR} rows at a time.
  while XNN_LIKELY(mc >= ${MR} * sizeof(float)) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t n = nc;
    do {
      uint32_t nnz = *nnzmap++;
      $if UNROLL > 1:
        // ${UNROLL} independent accumulator sets; only set 0 is seeded from the
        // weights stream, the rest start at zero and are summed in afterwards,
        // so the extra sets exist purely to break the FMA dependency chain.
        float32x4_t vacc0123x0 = vld1q_dup_f32(w); w += 1;
        $for K in range(1, UNROLL):
          float32x4_t vacc0123x${K} = vmovq_n_f32(0.0f);
        $for M in range(4, MR, 4):
          float32x4_t vacc${ABC[M:M+4]}x0 = vacc0123x0;
          $for K in range(1, UNROLL):
            float32x4_t vacc${ABC[M:M+4]}x${K} = vmovq_n_f32(0.0f);
        // Unrolled loop: consume ${UNROLL} non-zeros per iteration.
        for (; nnz >= ${UNROLL}; nnz -= ${UNROLL}) {
          $for K in range(UNROLL):
            const intptr_t diff${K} = dmap[${K}];
          dmap += ${UNROLL};
          $for K in range(UNROLL):
            const float32x4_t vi0123x${K} = vld1q_f32(input);
            $for M in range(4, MR, 4):
              const float32x4_t vi${ABC[M:M+4]}x${K} = vld1q_f32(input + ${M});
            // diff is a byte offset, hence the uintptr_t round-trip.
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff${K});
            $for M in range(0, MR, 16):
              __builtin_prefetch(input + ${M+16});
            // One scalar weight per non-zero, broadcast across the whole tile.
            const float32x4_t vw${K} = vld1q_dup_f32(w); w += 1;
            __builtin_prefetch(w + 32);
            $for M in range(0, MR, 4):
              vacc${ABC[M:M+4]}x${K} = ${VMULADDQ_F32}(vacc${ABC[M:M+4]}x${K}, vi${ABC[M:M+4]}x${K}, vw${K});
        }
        // Reduce the ${UNROLL} partial-accumulator sets into set 0's registers.
        $for M in range(0, MR, 4):
          float32x4_t vacc${ABC[M:M+4]} = vacc${ABC[M:M+4]}x0;
        $for K in range(1, UNROLL):
          $for M in range(0, MR, 4):
            vacc${ABC[M:M+4]} = vaddq_f32(vacc${ABC[M:M+4]}, vacc${ABC[M:M+4]}x${K});
      $else:
        // No unrolling: a single accumulator set, seeded from the weights stream.
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        $for M in range(4, MR, 4):
          float32x4_t vacc${ABC[M:M+4]} = vacc0123;
      // Remainder (or, without unrolling, all) non-zeros, one at a time.
      if XNN_LIKELY(nnz != 0) {
        do {
          const intptr_t diff = *dmap++;
          const float32x4_t vi0123 = vld1q_f32(input);
          $for M in range(4, MR, 4):
            const float32x4_t vi${ABC[M:M+4]} = vld1q_f32(input + ${M});
          input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
          $for M in range(0, MR, 16):
            __builtin_prefetch(input + ${M+16});
          const float32x4_t vw = vld1q_dup_f32(w); w += 1;
          __builtin_prefetch(w + 32);
          $for M in range(0, MR, 4):
            vacc${ABC[M:M+4]} = ${VMULADDQ_F32}(vacc${ABC[M:M+4]}, vi${ABC[M:M+4]}, vw);
        } while (--nnz != 0);
      }
      // Clamp to [min, max] and store ${MR} output values for this channel.
      $for M in range(0, MR, 4):
        float32x4_t vout${ABC[M:M+4]} = vminq_f32(vacc${ABC[M:M+4]}, vmax);
      $for M in range(0, MR, 4):
        vout${ABC[M:M+4]} = vmaxq_f32(vout${ABC[M:M+4]}, vmin);
      vst1q_f32(output, vout0123);
      $for M in range(4, MR, 4):
        vst1q_f32(output + ${M}, vout${ABC[M:M+4]});
      output = (float*restrict) ((uintptr_t) output + output_stride);
    } while (--n != 0);
    // Rewind output to the next row group; input advances by the tile height.
    output = (float*restrict) ((uintptr_t) output - output_decrement);
    input += ${MR};
    mc -= ${MR} * sizeof(float);
  }
  // Tail: fewer than ${MR} rows left.  Handle them with a descending ladder of
  // power-of-two sub-tiles (SUBMR), testing one bit of mc per step.
  if XNN_UNLIKELY(mc != 0) {
    $for LOG2M in reversed(range((MR - 1).bit_length())):
      $SUBMR = 1 << LOG2M
      $if SUBMR * 2 >= MR:
        output_decrement += ${MR - SUBMR} * sizeof(float);
      $else:
        output_decrement += ${SUBMR} * sizeof(float);
      if (mc & (${SUBMR} * sizeof(float))) {
        const float*restrict w = weights;
        const int32_t* dmap = widx_dmap;
        const uint32_t* nnzmap = nidx_nnzmap;
        size_t n = nc;
        do {
          uint32_t nnz = *nnzmap++;
          $if SUBMR <= 2:
            float32x2_t vacc${ABC[0:SUBMR]} = vld1_dup_f32(w); w += 1;
          $else:
            float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
          $for M in range(4, SUBMR, 4):
            float32x4_t vacc${ABC[M:M+4]} = vacc0123;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              $if SUBMR == 1:
                const float32x2_t vi0 = vld1_dup_f32(input);
              $elif SUBMR == 2:
                const float32x2_t vi01 = vld1_f32(input);
              $else:
                const float32x4_t vi0123 = vld1q_f32(input);
              $for M in range(4, SUBMR, 4):
                const float32x4_t vi${ABC[M:M+4]} = vld1q_f32(input + ${M});
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              $if SUBMR <= 2:
                const float32x2_t vw = vld1_dup_f32(w); w += 1;
              $else:
                const float32x4_t vw = vld1q_dup_f32(w); w += 1;
              $if SUBMR <= 2:
                vacc${ABC[0:SUBMR]} = ${VMULADD_F32}(vacc${ABC[0:SUBMR]}, vi${ABC[0:SUBMR]}, vw);
              $else:
                $for M in range(0, SUBMR, 4):
                  vacc${ABC[M:M+4]} = ${VMULADDQ_F32}(vacc${ABC[M:M+4]}, vi${ABC[M:M+4]}, vw);
            } while (--nnz != 0);
          }
          // Clamp and store; 64-bit lanes for SUBMR <= 2, 128-bit otherwise.
          $if SUBMR <= 2:
            float32x2_t vout${ABC[0:SUBMR]} = vmin_f32(vacc${ABC[0:SUBMR]}, vget_low_f32(vmax));
            vout${ABC[0:SUBMR]} = vmax_f32(vout${ABC[0:SUBMR]}, vget_low_f32(vmin));
            $if SUBMR == 1:
              vst1_lane_f32(output, vout${ABC[0]}, 0);
            $else:
              vst1_f32(output, vout${ABC[0:SUBMR]});
          $else:
            $for M in range(0, SUBMR, 4):
              float32x4_t vout${ABC[M:M+4]} = vminq_f32(vacc${ABC[M:M+4]}, vmax);
            $for M in range(0, SUBMR, 4):
              vout${ABC[M:M+4]} = vmaxq_f32(vout${ABC[M:M+4]}, vmin);
            vst1q_f32(output, vout0123);
            $for M in range(4, SUBMR, 4):
              vst1q_f32(output + ${M}, vout${ABC[M:M+4]});
          output = (float*restrict) ((uintptr_t) output + output_stride);
        } while (--n != 0);
        output = (float*restrict) ((uintptr_t) output - output_decrement);
        input += ${SUBMR};
      }
  }
}
168