1// Copyright 2019 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert CHANNEL_TILE % 4 == 0
7$assert CHANNEL_TILE >= 4
8$assert ROW_TILE >= 1
9$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
10#include <assert.h>
11
12#include <xmmintrin.h>
13
14#include <xnnpack/math.h>
15#include <xnnpack/vmulcaddc.h>
16
17
18void xnn_f32_vmulcaddc_minmax_ukernel_c${CHANNEL_TILE}__sse_${ROW_TILE}x(
19    size_t rows,
20    size_t channels,
21    const float*restrict input,
22    size_t input_stride,
23    const float*restrict weights,
24    float*restrict output,
25    size_t output_stride,
26    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
27{
28  assert(rows != 0);
29  assert(channels != 0);
30  assert(channels % sizeof(float) == 0);
31
32  const float* i0 = input;
33  float* o0 = output;
34  $for M in range(1, ROW_TILE):
35    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_stride);
36    float* o${M} = (float*) ((uintptr_t) o${M-1} + output_stride);
37    $if M % 2 == 0:
38      if XNN_UNPREDICTABLE(rows <= ${M}) {
39        i${M} = i${M-1};
40        o${M} = o${M-1};
41      }
42    $else:
43      if XNN_UNPREDICTABLE(rows < ${M+1}) {
44        i${M} = i${M-1};
45        o${M} = o${M-1};
46      }
47
48  const size_t input_increment = input_stride * ${ROW_TILE} - channels;
49  const size_t output_increment = output_stride * ${ROW_TILE} - channels;
50
51  const __m128 vmin = _mm_load_ps(params->sse.min);
52  const __m128 vmax = _mm_load_ps(params->sse.max);
53  do {
54    const float* w = weights;
55    size_t c = channels;
56    for (; c >= ${CHANNEL_TILE} * sizeof(float); c -= ${CHANNEL_TILE} * sizeof(float)) {
57      const __m128 vscale${ABC[0:4]} = _mm_load_ps(w);
58      $for C in range(4, CHANNEL_TILE, 4):
59        const __m128 vscale${ABC[C:C+4]} = _mm_load_ps(w + ${C});
60
61      $for M in range(ROW_TILE):
62        __m128 vacc${M}x${ABC[0:4]} = _mm_loadu_ps(i${M});
63        $for C in range(4, CHANNEL_TILE, 4):
64          __m128 vacc${M}x${ABC[C:C+4]} = _mm_loadu_ps(i${M} + ${C});
65        i${M} += ${CHANNEL_TILE};
66
67      $for M in range(ROW_TILE):
68        $for C in range(0, CHANNEL_TILE, 4):
69          vacc${M}x${ABC[C:C+4]} = _mm_mul_ps(vacc${M}x${ABC[C:C+4]}, vscale${ABC[C:C+4]});
70
71      $for C in range(0, CHANNEL_TILE, 4):
72        const __m128 vbias${ABC[C:C+4]} = _mm_load_ps(w + ${C + CHANNEL_TILE});
73
74      $for M in range(ROW_TILE):
75        $for C in range(0, CHANNEL_TILE, 4):
76          vacc${M}x${ABC[C:C+4]} = _mm_add_ps(vacc${M}x${ABC[C:C+4]}, vbias${ABC[C:C+4]});
77
78      $for M in range(ROW_TILE):
79        $for C in range(0, CHANNEL_TILE, 4):
80          vacc${M}x${ABC[C:C+4]} = _mm_max_ps(vacc${M}x${ABC[C:C+4]}, vmin);
81
82      $for M in range(ROW_TILE):
83        $for C in range(0, CHANNEL_TILE, 4):
84          vacc${M}x${ABC[C:C+4]} = _mm_min_ps(vacc${M}x${ABC[C:C+4]}, vmax);
85
86      $for M in range(ROW_TILE):
87        _mm_storeu_ps(o${M}, vacc${M}x${ABC[0:4]});
88        $for C in range(4, CHANNEL_TILE, 4):
89          _mm_storeu_ps(o${M} + ${C}, vacc${M}x${ABC[C:C+4]});
90        o${M} += ${CHANNEL_TILE};
91
92      w += ${CHANNEL_TILE * 2};
93    }
94    $if CHANNEL_TILE > 4:
95      for (; c >= 4 * sizeof(float); c -= 4 * sizeof(float)) {
96        const __m128 vscale0123 = _mm_load_ps(w);
97
98        $for M in range(ROW_TILE):
99          __m128 vacc${M}x0123 = _mm_loadu_ps(i${M});
100          i${M} += 4;
101
102        $for M in range(ROW_TILE):
103          vacc${M}x0123 = _mm_mul_ps(vacc${M}x0123, vscale0123);
104
105        const __m128 vbias0123 = _mm_load_ps(w + ${CHANNEL_TILE});
106
107        $for M in range(ROW_TILE):
108          vacc${M}x0123 = _mm_add_ps(vacc${M}x0123, vbias0123);
109
110        $for M in range(ROW_TILE):
111          vacc${M}x0123 = _mm_max_ps(vacc${M}x0123, vmin);
112
113        $for M in range(ROW_TILE):
114          vacc${M}x0123 = _mm_min_ps(vacc${M}x0123, vmax);
115
116        $for M in range(ROW_TILE):
117          _mm_storeu_ps(o${M}, vacc${M}x0123);
118          o${M} += 4;
119
120        w += 4;
121      }
122    if XNN_UNLIKELY(c != 0) {
123      const __m128 vscale0123 = _mm_load_ps(w);
124
125      $for M in range(ROW_TILE):
126        __m128 vacc${M}x0123 = _mm_loadu_ps(i${M});
127        i${M} = (const float*) ((uintptr_t) i${M} + c);
128
129      $for M in range(ROW_TILE):
130        vacc${M}x0123 = _mm_mul_ps(vacc${M}x0123, vscale0123);
131
132      const __m128 vbias0123 = _mm_load_ps(w + ${CHANNEL_TILE});
133
134      $for M in range(ROW_TILE):
135        vacc${M}x0123 = _mm_add_ps(vacc${M}x0123, vbias0123);
136
137      $for M in range(ROW_TILE):
138        vacc${M}x0123 = _mm_max_ps(vacc${M}x0123, vmin);
139
140      $for M in range(ROW_TILE):
141        vacc${M}x0123 = _mm_min_ps(vacc${M}x0123, vmax);
142
143      if (c & (2 * sizeof(float))) {
144        $for M in range(ROW_TILE):
145          _mm_storel_pi((__m64*) o${M}, vacc${M}x0123);
146
147        $for M in range(ROW_TILE):
148          vacc${M}x0123 = _mm_movehl_ps(vacc${M}x0123, vacc${M}x0123);
149
150        $for M in range(ROW_TILE):
151          o${M} += 2;
152      }
153      if (c & (1 * sizeof(float))) {
154        $for M in range(ROW_TILE):
155          _mm_store_ss(o${M}, vacc${M}x0123);
156
157        $for M in range(ROW_TILE):
158          o${M} += 1;
159      }
160    }
161    $for M in range(ROW_TILE):
162      i${M} = (const float*) ((uintptr_t) i${M} + input_increment);
163      o${M} = (float*) ((uintptr_t) o${M} + output_increment);
164      $if M % 2 == 1:
165        if XNN_UNPREDICTABLE(rows < ${ROW_TILE + M + 1}) {
166          i${M} = i${M-1};
167          o${M} = o${M-1};
168        }
169      $elif M != 0:
170        if XNN_UNPREDICTABLE(rows <= ${ROW_TILE + M}) {
171          i${M} = i${M-1};
172          o${M} = o${M-1};
173        }
174    rows = doz(rows, ${ROW_TILE});
175  } while (rows != 0);
176}
177