1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
// pack_avx.h: optimized AVX2 specializations of the templates in pack.h.
16 
17 #ifndef GEMMLOWP_INTERNAL_PACK_AVX_H_
18 #define GEMMLOWP_INTERNAL_PACK_AVX_H_
19 
20 #include <immintrin.h>
21 #include "pack.h"
22 
23 namespace gemmlowp {
24 
25 // TODO: Add DepthMajorUint8SideMap
26 
27 typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
28     WidthMajorUint8SideMap;
29 
// Kernel-side format made of `Cells` width-major cells, each 8 wide (width
// indices) and 2 deep (depth levels): CellFormat<8, 2, WidthMajor>.
// NOTE(review): despite "4x2" in the alias name, these cells are 8x2; the name
// is left as-is because other code may refer to it.
template <int Cells>
using WidthMajorSideFormatNCells4x2 =
    KernelSideFormat<CellFormat<8, 2, CellOrder::WidthMajor>, Cells>;
33 
34 template <int Cells>
35 class PackingRegisterBlock<
36     WidthMajorUint8SideMap,
37     PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>>
38     : public PackingRegisterBlockBase<
39           WidthMajorUint8SideMap,
40           PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> {
41  public:
42   typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
43   typedef typename KernelSideFormat::Cell CellFormat;
44   static const int kCells = KernelSideFormat::kCells;
45   static const int kCellWidth = CellFormat::kWidth;
46   static const int kKernelWidth = CellFormat::kWidth * kCells;
47   static const int kCellDepth = CellFormat::kDepth;
48   static const int kCellSize = CellFormat::kSize;
49 
Pack(PackedSideBlock<KernelSideFormat> * dst,int start_width)50   void Pack(PackedSideBlock<KernelSideFormat> *dst, int start_width) {
51     std::uint8_t *dst_ptr = dst->current_data();
52     const int width_stride = this->complete_src_.width_stride();
53     int depth_step = 16;
54 
55     __m256i one = _mm256_set1_epi16(1);
56     for (int cell_start_depth = 0; cell_start_depth < kRegisterSize;
57          cell_start_depth += depth_step) {
58       for (int cell_start_width = 0; cell_start_width < kKernelWidth;
59            cell_start_width += kCellWidth) {
60         std::int32_t *cell_sums_of_each_slice_ptr =
61             dst->sums_of_each_slice() + start_width + cell_start_width;
62         const std::uint8_t *src_data =
63             this->complete_src_.data(cell_start_width, cell_start_depth);
64 
65         __m128i xmm1 =
66             _mm_loadu_si128(reinterpret_cast<const __m128i *>(&src_data[0]));
67         __m128i xmm2 = _mm_loadu_si128(
68             reinterpret_cast<const __m128i *>(&src_data[1 * width_stride]));
69         __m128i xmm3 = _mm_loadu_si128(
70             reinterpret_cast<const __m128i *>(&src_data[2 * width_stride]));
71         __m128i xmm4 = _mm_loadu_si128(
72             reinterpret_cast<const __m128i *>(&src_data[3 * width_stride]));
73         __m128i xmm5 = _mm_loadu_si128(
74             reinterpret_cast<const __m128i *>(&src_data[4 * width_stride]));
75         __m128i xmm6 = _mm_loadu_si128(
76             reinterpret_cast<const __m128i *>(&src_data[5 * width_stride]));
77         __m128i xmm7 = _mm_loadu_si128(
78             reinterpret_cast<const __m128i *>(&src_data[6 * width_stride]));
79         __m128i xmm8 = _mm_loadu_si128(
80             reinterpret_cast<const __m128i *>(&src_data[7 * width_stride]));
81 
82         __m256i ymm1 = _mm256_set_m128i(xmm5, xmm1);
83         __m256i ymm2 = _mm256_set_m128i(xmm6, xmm2);
84         __m256i ymm3 = _mm256_set_m128i(xmm7, xmm3);
85         __m256i ymm4 = _mm256_set_m128i(xmm8, xmm4);
86 
87         __m256i ymm5 = _mm256_unpacklo_epi16(ymm1, ymm2);
88         __m256i ymm6 = _mm256_unpacklo_epi16(ymm3, ymm4);
89 
90         __m256i ymm9 = _mm256_unpackhi_epi16(ymm1, ymm2);
91         __m256i ymm10 = _mm256_unpackhi_epi16(ymm3, ymm4);
92 
93         __m256i ymm7 = _mm256_unpacklo_epi32(ymm5, ymm6);
94         __m256i ymm8 = _mm256_unpackhi_epi32(ymm5, ymm6);
95 
96         __m256i ymm13 = _mm256_unpacklo_epi32(ymm9, ymm10);
97         __m256i ymm14 = _mm256_unpackhi_epi32(ymm9, ymm10);
98 
99         __m256i ymm11 = _mm256_permute4x64_epi64(ymm7, 0xd8);
100         __m256i ymm12 = _mm256_permute4x64_epi64(ymm8, 0xd8);
101 
102         __m256i ymm15 = _mm256_permute4x64_epi64(ymm13, 0xd8);
103         __m256i ymm16 = _mm256_permute4x64_epi64(ymm14, 0xd8);
104 
105         __m128i xmm9 = _mm256_castsi256_si128(ymm11);
106         __m128i xmm10 = _mm256_castsi256_si128(ymm12);
107         __m128i xmm11 = _mm256_extracti128_si256(ymm11, 1);
108         __m128i xmm12 = _mm256_extracti128_si256(ymm12, 1);
109 
110         xmm1 = _mm256_castsi256_si128(ymm15);
111         xmm2 = _mm256_castsi256_si128(ymm16);
112         xmm3 = _mm256_extracti128_si256(ymm15, 1);
113         xmm4 = _mm256_extracti128_si256(ymm16, 1);
114 
115         _mm_storeu_si128(reinterpret_cast<__m128i *>(&dst_ptr[0]), xmm9);
116         _mm_storeu_si128(
117             reinterpret_cast<__m128i *>(&dst_ptr[kCellSize * kCells]), xmm11);
118         _mm_storeu_si128(
119             reinterpret_cast<__m128i *>(&dst_ptr[2 * kCellSize * kCells]),
120             xmm10);
121         _mm_storeu_si128(
122             reinterpret_cast<__m128i *>(&dst_ptr[3 * kCellSize * kCells]),
123             xmm12);
124         _mm_storeu_si128(
125             reinterpret_cast<__m128i *>(&dst_ptr[4 * kCellSize * kCells]),
126             xmm1);
127         _mm_storeu_si128(
128             reinterpret_cast<__m128i *>(&dst_ptr[5 * kCellSize * kCells]),
129             xmm3);
130 
131         _mm_storeu_si128(
132             reinterpret_cast<__m128i *>(&dst_ptr[6 * kCellSize * kCells]),
133             xmm2);
134         _mm_storeu_si128(
135             reinterpret_cast<__m128i *>(&dst_ptr[7 * kCellSize * kCells]),
136             xmm4);
137 
138         ymm6 = _mm256_cvtepu8_epi16(xmm9);
139         ymm7 = _mm256_madd_epi16(ymm6, one);
140         __m256i sums_of_each_slice_xmm = _mm256_loadu_si256(
141             reinterpret_cast<const __m256i *>(&cell_sums_of_each_slice_ptr[0]));
142         sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
143 
144         ymm6 = _mm256_cvtepu8_epi16(xmm11);
145         ymm7 = _mm256_madd_epi16(ymm6, one);
146         sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
147 
148         ymm6 = _mm256_cvtepu8_epi16(xmm10);
149         ymm7 = _mm256_madd_epi16(ymm6, one);
150         sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
151 
152         ymm6 = _mm256_cvtepu8_epi16(xmm12);
153         ymm7 = _mm256_madd_epi16(ymm6, one);
154         sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
155 
156         ymm6 = _mm256_cvtepu8_epi16(xmm1);
157         ymm7 = _mm256_madd_epi16(ymm6, one);
158         sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
159 
160         ymm6 = _mm256_cvtepu8_epi16(xmm3);
161         ymm7 = _mm256_madd_epi16(ymm6, one);
162         sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
163 
164         ymm6 = _mm256_cvtepu8_epi16(xmm2);
165         ymm7 = _mm256_madd_epi16(ymm6, one);
166         sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
167 
168         ymm6 = _mm256_cvtepu8_epi16(xmm4);
169         ymm7 = _mm256_madd_epi16(ymm6, one);
170         sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
171 
172         _mm256_storeu_si256(
173             reinterpret_cast<__m256i *>(&cell_sums_of_each_slice_ptr[0]),
174             sums_of_each_slice_xmm);
175         dst_ptr += kCellSize;
176       }
177       dst_ptr += 7 * kCellSize * kCells;
178     }
179     dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
180   }
181 };
182 
// Kernel-side format for the RHS: `Cells` width-major cells, each 4 wide
// (width indices) and 2 deep (depth levels): CellFormat<4, 2, WidthMajor>.
template <int Cells>
using RhsWidthMajorSideFormatNCells4x2 =
    KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;
187 
188 template <int Cells>
189 class PackingRegisterBlock<
190     WidthMajorUint8SideMap,
191     PackedSideBlock<RhsWidthMajorSideFormatNCells4x2<Cells>>>
192     : public PackingRegisterBlockBase<
193           WidthMajorUint8SideMap,
194           PackedSideBlock<RhsWidthMajorSideFormatNCells4x2<Cells>>> {
195  public:
196   typedef RhsWidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
197   typedef typename KernelSideFormat::Cell CellFormat;
198   static const int kCells = KernelSideFormat::kCells;
199   static const int kCellWidth = CellFormat::kWidth;
200   static const int kKernelWidth = CellFormat::kWidth * kCells;
201   static const int kCellDepth = CellFormat::kDepth;
202   static const int kCellSize = CellFormat::kSize;
203 
Pack(PackedSideBlock<KernelSideFormat> * dst,int start_width)204   void Pack(PackedSideBlock<KernelSideFormat> *dst, int start_width) {
205     std::uint8_t *dst_ptr = dst->current_data();
206     const int width_stride = this->complete_src_.width_stride();
207     int depth_step = 8;
208 
209     __m128i one = _mm_set1_epi16(1);
210     for (int cell_start_depth = 0; cell_start_depth < kRegisterSize;
211          cell_start_depth += depth_step) {
212       for (int cell_start_width = 0; cell_start_width < kKernelWidth;
213            cell_start_width += kCellWidth) {
214         std::int32_t *cell_sums_of_each_slice_ptr =
215             dst->sums_of_each_slice() + start_width + cell_start_width;
216         const std::uint8_t *src_data =
217             this->complete_src_.data(cell_start_width, cell_start_depth);
218 
219         __m128i xmm1 =
220             _mm_loadl_epi64(reinterpret_cast<const __m128i *>(&src_data[0]));
221         __m128i xmm2 = _mm_loadl_epi64(
222             reinterpret_cast<const __m128i *>(&src_data[1 * width_stride]));
223         __m128i xmm3 = _mm_loadl_epi64(
224             reinterpret_cast<const __m128i *>(&src_data[2 * width_stride]));
225         __m128i xmm4 = _mm_loadl_epi64(
226             reinterpret_cast<const __m128i *>(&src_data[3 * width_stride]));
227 
228         __m128i xmm5 = _mm_unpacklo_epi16(xmm1, xmm2);
229         __m128i xmm8 = _mm_shuffle_epi32(xmm5, 0x31);
230 
231         __m128i xmm6 = _mm_unpacklo_epi16(xmm3, xmm4);
232         __m128i xmm7 = _mm_shuffle_epi32(xmm6, 0x80);
233 
234         __m128i xmm9 = _mm_blend_epi16(xmm5, xmm7, 0xcc);
235         __m128i xmm10 = _mm_blend_epi16(xmm8, xmm6, 0xcc);
236 
237         _mm_storel_epi64(reinterpret_cast<__m128i *>(&dst_ptr[0]), xmm9);
238         _mm_storel_epi64(
239             reinterpret_cast<__m128i *>(&dst_ptr[kCellSize * kCells]), xmm10);
240 
241         __m128i xmm11 = _mm_shuffle_epi32(xmm9, 0xee);
242         __m128i xmm12 = _mm_shuffle_epi32(xmm10, 0xee);
243 
244         _mm_storel_epi64(
245             reinterpret_cast<__m128i *>(&dst_ptr[2 * kCellSize * kCells]),
246             xmm11);
247         _mm_storel_epi64(
248             reinterpret_cast<__m128i *>(&dst_ptr[3 * kCellSize * kCells]),
249             xmm12);
250 
251         xmm1 = _mm_cvtepu8_epi16(xmm9);
252         xmm2 = _mm_madd_epi16(xmm1, one);
253         __m128i sums_of_each_slice_xmm = _mm_loadu_si128(
254             reinterpret_cast<const __m128i *>(&cell_sums_of_each_slice_ptr[0]));
255         sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
256 
257         xmm1 = _mm_cvtepu8_epi16(xmm10);
258         xmm2 = _mm_madd_epi16(xmm1, one);
259         sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
260 
261         xmm1 = _mm_cvtepu8_epi16(xmm11);
262         xmm2 = _mm_madd_epi16(xmm1, one);
263         sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
264 
265         xmm1 = _mm_cvtepu8_epi16(xmm12);
266         xmm2 = _mm_madd_epi16(xmm1, one);
267         sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
268 
269         _mm_storeu_si128(
270             reinterpret_cast<__m128i *>(&cell_sums_of_each_slice_ptr[0]),
271             sums_of_each_slice_xmm);
272         dst_ptr += kCellSize;
273       }
274       dst_ptr += 3 * kCellSize * kCells;
275     }
276     dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
277   }
278 };
279 
280 }  // namespace gemmlowp
281 
282 #endif  // GEMMLOWP_INTERNAL_PACK_AVX_H_
283