1// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
7$assert CHANNEL_TILE % 16 == 0
8$assert CHANNEL_TILE >= 16
9$assert KERNEL_TILE >= 2
#include <assert.h>
#include <string.h>

#include <immintrin.h>

#include <xnnpack/dwconv.h>
15
16
17void xnn_qs8_dwconv_minmax_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__avx2_mul16(
18    size_t channels,
19    size_t output_width,
20    const int8_t** input,
21    const void* weights,
22    int8_t* output,
23    size_t input_stride,
24    size_t output_increment,
25    size_t input_offset,
26    const int8_t* zero,
27    const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
28{
29  assert(channels != 0);
30  assert(output_width != 0);
31
32  do {
33    $for K in range(KERNEL_TILE):
34      const int8_t* i${K} = input[${K}];
35      assert(i${K} != NULL);
36      if XNN_UNPREDICTABLE(i${K} != zero) {
37        i${K} = (const int8_t*) ((uintptr_t) i${K} + input_offset);
38      }
39    input = (const int8_t**) ((uintptr_t) input + input_stride);
40
41    size_t c = channels;
42    const void* w = weights;
43    for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
44      __m256i vacc${ABC[0:8]} = _mm256_loadu_si256((const __m256i*) w);
45      $for C in range(8, CHANNEL_TILE, 8):
46        __m256i vacc${ABC[C:C+8]} = _mm256_loadu_si256((const __m256i*) ((uintptr_t) w + ${C} * sizeof(int32_t)));
47
48      $for K in range(KERNEL_TILE):
49
50        $for C in range(0, CHANNEL_TILE, 16):
51          $if C == 0:
52            const __m256i vi${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) i${K}));
53          $else:
54            const __m256i vi${K}x${ABC[C:C+16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) (i${K} + ${C})));
55          const __m256i vk${K}x${ABC[C:C+16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE + C} * sizeof(int8_t))));
56        i${K} += ${CHANNEL_TILE};
57
58        $for C in range(0, CHANNEL_TILE, 16):
59          const __m256i vprod${K}x${ABC[C:C+16]} =  _mm256_mullo_epi16(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]});
60          const __m128i vprod${K}x${ABC[C+8:C+16]} = _mm256_extracti128_si256(vprod${K}x${ABC[C:C+16]}, 1);
61          vacc${ABC[C:C+8]} = _mm256_add_epi32(vacc${ABC[C:C+8]}, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vprod${K}x${ABC[C:C+16]})));
62          vacc${ABC[C+8:C+16]} = _mm256_add_epi32(vacc${ABC[C+8:C+16]}, _mm256_cvtepi16_epi32(vprod${K}x${ABC[C+8:C+16]}));
63
64      w = (const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(int8_t));
65
66      const __m256i vmultiplier = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.multiplier));
67      const __m256i vrounding = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.rounding));
68
69      $for C in range(0, CHANNEL_TILE, 8):
70        const __m256i vacc${ABC[C+1:C+8:2]} = _mm256_shuffle_epi32(vacc${ABC[C:C+8]}, _MM_SHUFFLE(3, 3, 1, 1));
71
72      $for C in range(0, CHANNEL_TILE, 8):
73        const __m256i vprod${ABC[C:C+8:2]} = _mm256_add_epi64(_mm256_mul_epi32(vacc${ABC[C:C+8]}, vmultiplier), vrounding);
74        const __m256i vprod${ABC[C+1:C+8:2]} = _mm256_add_epi64(_mm256_mul_epi32(vacc${ABC[C+1:C+8:2]}, vmultiplier), vrounding);
75
76      $for C in range(0, CHANNEL_TILE, 8):
77        const __m256i vq31prod${ABC[C:C+8:2]} = _mm256_srli_epi64(vprod${ABC[C:C+8:2]}, 31);
78        const __m256i vq31prod${ABC[C+1:C+8:2]} = _mm256_add_epi64(vprod${ABC[C+1:C+8:2]}, vprod${ABC[C+1:C+8:2]});
79
80      $for C in range(0, CHANNEL_TILE, 8):
81        const __m256i vq31prod${ABC[C:C+8]} = _mm256_blend_epi16(vq31prod${ABC[C:C+8:2]}, vq31prod${ABC[C+1:C+8:2]}, 0xCC);
82
83      const __m256i vremainder_mask = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.remainder_mask));
84      $for C in range(0, CHANNEL_TILE, 8):
85        const __m256i vrem${ABC[C:C+8]} =
86          _mm256_add_epi32(_mm256_and_si256(vq31prod${ABC[C:C+8]}, vremainder_mask), _mm256_cmpgt_epi32(_mm256_setzero_si256(), vq31prod${ABC[C:C+8]}));
87
88      const __m256i vremainder_threshold = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.remainder_threshold));
89      const __m128i vshift = _mm_load_si128((const __m128i*) params->sse2.shift);
90      $for C in range(0, CHANNEL_TILE, 8):
91        vacc${ABC[C:C+8]} =
92          _mm256_sub_epi32(_mm256_sra_epi32(vq31prod${ABC[C:C+8]}, vshift), _mm256_cmpgt_epi32(vrem${ABC[C:C+8]}, vremainder_threshold));
93
94      const __m256i voutput_zero_point = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.output_zero_point));
95      $for C in range(0, CHANNEL_TILE, 16):
96        __m256i vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_adds_epi16(_mm256_packs_epi32(vacc${ABC[C:C+8]}, vacc${ABC[C+8:C+16]}), voutput_zero_point);
97
98      const __m256i voutput_min = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.output_min));
99      const __m256i voutput_max = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.output_max));
100      $for C in range(0, CHANNEL_TILE, 16):
101        vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_min_epi16(_mm256_max_epi16(vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]}, voutput_min), voutput_max);
102
103      $for C in range(0, CHANNEL_TILE, 16):
104        __m128i vout${ABC[C:C+16]} = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]}), _mm256_extracti128_si256(vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]}, 1)), _MM_SHUFFLE(3, 1, 2, 0));
105
106      _mm_storeu_si128((__m128i*) output, vout${ABC[0:16]});
107      $for C in range(16, CHANNEL_TILE, 16):
108        _mm_storeu_si128((__m128i*) (output + ${C}), vout${ABC[C:C+16]});
109      output += ${CHANNEL_TILE};
110    }
111    if XNN_UNLIKELY(c != 0) {
112      $if CHANNEL_TILE > 16:
113        const int8_t* k = (const int8_t*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t));
114      ${"do " if CHANNEL_TILE > 16 else ""}{
115        __m256i vacc${ABC[0:8]} = _mm256_loadu_si256((const __m256i*) w);
116        __m256i vacc${ABC[8:16]} = _mm256_loadu_si256((const __m256i*) ((uintptr_t) w + 8 * sizeof(int32_t)));
117
118        $for K in range(KERNEL_TILE):
119
120          const __m256i vi${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) i${K}));
121          $if CHANNEL_TILE > 16:
122            $if K == 0:
123              const __m256i vk${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) k));
124            $else:
125              const __m256i vk${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) (k + ${K * CHANNEL_TILE})));
126          $else:
127            const __m256i vk${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE} * sizeof(int8_t))));
128          $if CHANNEL_TILE > 16:
129            i${K} += 16;
130
131          const __m256i vprod${K}x${ABC[0:16]} = _mm256_mullo_epi16(vi${K}x${ABC[0:16]}, vk${K}x${ABC[0:16]});
132          const __m128i vprod${K}x${ABC[8:16]} = _mm256_extracti128_si256(vprod${K}x${ABC[0:16]}, 1);
133          vacc${ABC[0:8]} = _mm256_add_epi32(vacc${ABC[0:8]}, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vprod${K}x${ABC[0:16]})));
134          vacc${ABC[8:16]} = _mm256_add_epi32(vacc${ABC[8:16]}, _mm256_cvtepi16_epi32(vprod${K}x${ABC[8:16]}));
135
136        $if CHANNEL_TILE > 16:
137          w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
138          k += 16;
139
140        const __m256i vmultiplier = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.multiplier));
141        const __m256i vrounding = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.rounding));
142
143        const __m256i vacc${ABC[1:8:2]} = _mm256_shuffle_epi32(vacc${ABC[0:8]}, _MM_SHUFFLE(3, 3, 1, 1));
144        const __m256i vacc${ABC[9:16:2]} = _mm256_shuffle_epi32(vacc${ABC[8:16]}, _MM_SHUFFLE(3, 3, 1, 1));
145
146        const __m256i vprod${ABC[0:8:2]} = _mm256_add_epi64(_mm256_mul_epi32(vacc${ABC[0:8]}, vmultiplier), vrounding);
147        const __m256i vprod${ABC[1:8:2]} = _mm256_add_epi64(_mm256_mul_epi32(vacc${ABC[1:8:2]}, vmultiplier), vrounding);
148        const __m256i vprod${ABC[8:16:2]} = _mm256_add_epi64(_mm256_mul_epi32(vacc${ABC[8:16]}, vmultiplier), vrounding);
149        const __m256i vprod${ABC[9:16:2]} = _mm256_add_epi64(_mm256_mul_epi32(vacc${ABC[9:16:2]}, vmultiplier), vrounding);
150
151        const __m256i vq31prod${ABC[0:8:2]} = _mm256_srli_epi64(vprod${ABC[0:8:2]}, 31);
152        const __m256i vq31prod${ABC[1:8:2]} = _mm256_add_epi64(vprod${ABC[1:8:2]}, vprod${ABC[1:8:2]});
153        const __m256i vq31prod${ABC[8:16:2]} = _mm256_srli_epi64(vprod${ABC[8:16:2]}, 31);
154        const __m256i vq31prod${ABC[9:16:2]} = _mm256_add_epi64(vprod${ABC[9:16:2]}, vprod${ABC[9:16:2]});
155
156        const __m256i vq31prod${ABC[0:8]} = _mm256_blend_epi16(vq31prod${ABC[0:8:2]}, vq31prod${ABC[1:8:2]}, 0xCC);
157        const __m256i vq31prod${ABC[8:16]} = _mm256_blend_epi16(vq31prod${ABC[8:16:2]}, vq31prod${ABC[9:16:2]}, 0xCC);
158
159        const __m256i vremainder_mask = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.remainder_mask));
160        const __m256i vrem${ABC[0:8]} =
161          _mm256_add_epi32(_mm256_and_si256(vq31prod${ABC[0:8]}, vremainder_mask), _mm256_cmpgt_epi32(_mm256_setzero_si256(), vq31prod${ABC[0:8]}));
162        const __m256i vrem${ABC[8:16]} =
163          _mm256_add_epi32(_mm256_and_si256(vq31prod${ABC[8:16]}, vremainder_mask), _mm256_cmpgt_epi32(_mm256_setzero_si256(), vq31prod${ABC[8:16]}));
164
165        const __m256i vremainder_threshold = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.remainder_threshold));
166        const __m128i vshift = _mm_load_si128((const __m128i*) params->sse2.shift);
167        vacc${ABC[0:8]} =
168          _mm256_sub_epi32(_mm256_sra_epi32(vq31prod${ABC[0:8]}, vshift), _mm256_cmpgt_epi32(vrem${ABC[0:8]}, vremainder_threshold));
169        vacc${ABC[8:16]} =
170          _mm256_sub_epi32(_mm256_sra_epi32(vq31prod${ABC[8:16]}, vshift), _mm256_cmpgt_epi32(vrem${ABC[8:16]}, vremainder_threshold));
171
172        const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->sse2.output_zero_point);
173        __m128i vout${ABC[0:8]} = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc${ABC[0:8]}), _mm256_extracti128_si256(vacc${ABC[0:8]}, 1)), voutput_zero_point);
174        __m128i vout${ABC[8:16]} = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc${ABC[8:16]}), _mm256_extracti128_si256(vacc${ABC[8:16]}, 1)), voutput_zero_point);
175
176        const __m128i voutput_min = _mm_load_si128((const __m128i*) params->sse2.output_min);
177        const __m128i voutput_max = _mm_load_si128((const __m128i*) params->sse2.output_max);
178        vout${ABC[0:8]} = _mm_min_epi16(_mm_max_epi16(vout${ABC[0:8]}, voutput_min), voutput_max);
179        vout${ABC[8:16]} = _mm_min_epi16(_mm_max_epi16(vout${ABC[8:16]}, voutput_min), voutput_max);
180
181        __m128i vout${ABC[0:16]} = _mm_packs_epi16(vout${ABC[0:8]}, vout${ABC[8:16]});
182
183        $if CHANNEL_TILE > 16:
184          if XNN_LIKELY(c >= 16) {
185            _mm_storeu_si128((__m128i*) output, vout${ABC[0:16]});
186            output += 16;
187            c -= 16;
188          } else {
189            if (c & 8) {
190              _mm_storel_epi64((__m128i*) output, vout${ABC[0:16]});
191              vout${ABC[0:16]} = _mm_unpackhi_epi64(vout${ABC[0:16]}, vout${ABC[0:16]});
192              output += 8;
193            }
194            if (c & 4) {
195              *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:16]});
196              vout${ABC[0:16]} = _mm_srli_epi64(vout${ABC[0:16]}, 32);
197              output += 4;
198            }
199            if (c & 2) {
200              *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout${ABC[0:16]}, 0);
201              vout${ABC[0:16]} = _mm_srli_epi32(vout${ABC[0:16]}, 16);
202              output += 2;
203            }
204            if (c & 1) {
205              *output = (int8_t) _mm_extract_epi8(vout${ABC[0:16]}, 0);
206              output += 1;
207            }
208            c = 0;
209          }
210        $else:
211          if (c & 8) {
212            _mm_storel_epi64((__m128i*) output, vout${ABC[0:16]});
213            vout${ABC[0:16]} = _mm_unpackhi_epi64(vout${ABC[0:16]}, vout${ABC[0:16]});
214            output += 8;
215          }
216          if (c & 4) {
217            *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:16]});
218            vout${ABC[0:16]} = _mm_srli_epi64(vout${ABC[0:16]}, 32);
219            output += 4;
220          }
221          if (c & 2) {
222            *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout${ABC[0:16]}, 0);
223            vout${ABC[0:16]} = _mm_srli_epi32(vout${ABC[0:16]}, 16);
224            output += 2;
225          }
226          if (c & 1) {
227            *output = (int8_t) _mm_extract_epi8(vout${ABC[0:16]}, 0);
228            output += 1;
229          }
230      }${" while (c != 0);" if CHANNEL_TILE > 16 else ""}
231    }
232
233    output = (int8_t*) ((uintptr_t) output + output_increment);
234  } while (--output_width != 0);
235}
236