// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// pack_msa.h: optimized MSA specializations of the templates in pack.h.
16 
17 #ifndef GEMMLOWP_INTERNAL_PACK_MSA_H_
18 #define GEMMLOWP_INTERNAL_PACK_MSA_H_
19 
20 #include "pack.h"
21 
22 #include <msa.h>
23 
24 namespace gemmlowp {
25 
26 typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
27     WidthMajorUint8SideMap;
28 
29 template <int Cells>
30 using DepthMajorSideFormatNCells4x2 = KernelSideFormat<CellFormat<4, 2>, Cells>;
31 
32 template <int Cells>
33 class PackingRegisterBlock<
34     WidthMajorUint8SideMap,
35     PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>>
36     : public PackingRegisterBlockBase<
37           WidthMajorUint8SideMap,
38           PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>> {
39  public:
40   typedef DepthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
41   typedef typename KernelSideFormat::Cell CellFormat;
42   static constexpr int kCells = KernelSideFormat::kCells;
43   static const int kCellWidth = CellFormat::kWidth;
44   static const int kKernelWidth = CellFormat::kWidth * kCells;
45   static const int kCellDepth = CellFormat::kDepth;
46   static const int kCellSize = CellFormat::kSize;
47 
Pack(PackedSideBlock<KernelSideFormat> * dst,int start_width)48   void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
49     std::uint8_t* dst_ptr = dst->current_data();
50     const std::uint8_t* const src_ptr = this->complete_src_.data();
51     const int stride = this->complete_src_.stride();
52     // Load source WidthMajor data
53     v16i8 src_lines[4 * kCells];
54     for (int i = 0; i < 4 * kCells; i++) {
55       src_lines[i] = __builtin_msa_ld_b(
56           const_cast<std::uint8_t*>(src_ptr + i * stride), 0);
57     }
58     // Reorder the data within registers to make DepthMajor 4x2 cells
59     v16i8 src_lines_intertwined_2x[2 * kCells][2];
60     for (int i = 0; i < kCells; i++) {
61       src_lines_intertwined_2x[2 * i][0] =
62           __builtin_msa_ilvr_b(src_lines[4 * i + 2], src_lines[4 * i]);
63       src_lines_intertwined_2x[2 * i][1] =
64           __builtin_msa_ilvl_b(src_lines[4 * i + 2], src_lines[4 * i]);
65       src_lines_intertwined_2x[2 * i + 1][0] =
66           __builtin_msa_ilvr_b(src_lines[4 * i + 3], src_lines[4 * i + 1]);
67       src_lines_intertwined_2x[2 * i + 1][1] =
68           __builtin_msa_ilvl_b(src_lines[4 * i + 3], src_lines[4 * i + 1]);
69     }
70     v16i8 src_lines_intertwined_4x[2 * kCells][2];
71     for (int i = 0; i < kCells; i++) {
72       src_lines_intertwined_4x[2 * i][0] =
73           __builtin_msa_ilvr_b(src_lines_intertwined_2x[2 * i + 1][0],
74                                src_lines_intertwined_2x[2 * i][0]);
75       src_lines_intertwined_4x[2 * i][1] =
76           __builtin_msa_ilvl_b(src_lines_intertwined_2x[2 * i + 1][0],
77                                src_lines_intertwined_2x[2 * i][0]);
78       src_lines_intertwined_4x[2 * i + 1][0] =
79           __builtin_msa_ilvr_b(src_lines_intertwined_2x[2 * i + 1][1],
80                                src_lines_intertwined_2x[2 * i][1]);
81       src_lines_intertwined_4x[2 * i + 1][1] =
82           __builtin_msa_ilvl_b(src_lines_intertwined_2x[2 * i + 1][1],
83                                src_lines_intertwined_2x[2 * i][1]);
84     }
85     // Store the resulting DepthMajor 4x2 cells in the destination packed block
86     for (int outer = 0; outer < 2; outer++) {
87       for (int inner = 0; inner < 2; inner++) {
88         if (kCells % 2 == 0) {
89           for (int cell = 0; cell < kCells; cell += 2) {
90             v2i64 tmp = __builtin_msa_ilvr_d(
91                 reinterpret_cast<v2i64>(
92                     src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
93                 reinterpret_cast<v2i64>(
94                     src_lines_intertwined_4x[2 * cell + outer][inner]));
95             __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
96             dst_ptr += 16;
97           }
98           for (int cell = 0; cell < kCells; cell += 2) {
99             v2i64 tmp = __builtin_msa_ilvl_d(
100                 reinterpret_cast<v2i64>(
101                     src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
102                 reinterpret_cast<v2i64>(
103                     src_lines_intertwined_4x[2 * cell + outer][inner]));
104             __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
105             dst_ptr += 16;
106           }
107         } else {
108           // Store even number of low vector halves.
109           for (int cell = 0; cell < kCells - 1; cell += 2) {
110             v2i64 tmp = __builtin_msa_ilvr_d(
111                 reinterpret_cast<v2i64>(
112                     src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
113                 reinterpret_cast<v2i64>(
114                     src_lines_intertwined_4x[2 * cell + outer][inner]));
115             __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
116             dst_ptr += 16;
117           }
118           // Store last low half and first high half.
119           v2i64 tmp = reinterpret_cast<v2i64>(
120               src_lines_intertwined_4x[2 * 0 + outer][inner]);
121           tmp = __builtin_msa_insve_d(
122               tmp, 0,
123               reinterpret_cast<v2i64>(
124                   src_lines_intertwined_4x[2 * (kCells - 1) + outer][inner]));
125           __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
126           dst_ptr += 16;
127           // Store even number of high vector halves.
128           for (int cell = 1; cell < kCells; cell += 2) {
129             v2i64 tmp = __builtin_msa_ilvl_d(
130                 reinterpret_cast<v2i64>(
131                     src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
132                 reinterpret_cast<v2i64>(
133                     src_lines_intertwined_4x[2 * cell + outer][inner]));
134             __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
135             dst_ptr += 16;
136           }
137         }
138       }
139     }
140     // Compute sums across the depth dimension
141     v8i16 sums_of_2_cells[kCells][4];
142     const v16i8 zeroes = __builtin_msa_ldi_b(0);
143     for (int outer = 0; outer < 2; outer++) {
144       for (int inner = 0; inner < 2; inner++) {
145         int i = 2 * outer + inner;
146         for (int cell = 0; cell < kCells; cell++) {
147           v8i16 tmp0 = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(
148               zeroes, src_lines_intertwined_4x[2 * cell + outer][inner]));
149           v8i16 tmp1 = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(
150               zeroes, src_lines_intertwined_4x[2 * cell + outer][inner]));
151           sums_of_2_cells[cell][i] = __builtin_msa_addv_h(tmp0, tmp1);
152         }
153       }
154     }
155     v4i32 sums_of_4_cells[kCells][4];
156     for (int i = 0; i < 4; i++) {
157       for (int cell = 0; cell < kCells; cell++) {
158         v4i32 tmp0 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(
159             reinterpret_cast<v8i16>(zeroes), sums_of_2_cells[cell][i]));
160         v4i32 tmp1 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(
161             reinterpret_cast<v8i16>(zeroes), sums_of_2_cells[cell][i]));
162         sums_of_4_cells[cell][i] = __builtin_msa_addv_w(tmp0, tmp1);
163       }
164     }
165     // Update the sums_of_each_slice vector
166     for (int cell = 0; cell < kCells; cell++) {
167       v4i32 s01 = __builtin_msa_addv_w(sums_of_4_cells[cell][0],
168                                        sums_of_4_cells[cell][1]);
169       v4i32 s23 = __builtin_msa_addv_w(sums_of_4_cells[cell][2],
170                                        sums_of_4_cells[cell][3]);
171       v4i32 s = __builtin_msa_addv_w(s01, s23);
172       std::int32_t* sums_of_each_slice_ptr =
173           dst->sums_of_each_slice() + start_width + 4 * cell;
174       v4i32 tmp = __builtin_msa_ld_w(sums_of_each_slice_ptr, 0);
175       tmp = __builtin_msa_addv_w(tmp, s);
176       __builtin_msa_st_w(tmp, sums_of_each_slice_ptr, 0);
177     }
178     dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
179   }
180 };
181 
182 template <int Cells>
183 using WidthMajorSideFormatNCells4x2 =
184     KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;
185 
186 template <int Cells>
187 class PackingRegisterBlock<
188     WidthMajorUint8SideMap,
189     PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>>
190     : public PackingRegisterBlockBase<
191           WidthMajorUint8SideMap,
192           PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> {
193  public:
194   typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
195   typedef typename KernelSideFormat::Cell CellFormat;
196   static constexpr int kCells = KernelSideFormat::kCells;
197   static const int kCellWidth = CellFormat::kWidth;
198   static const int kKernelWidth = CellFormat::kWidth * kCells;
199   static const int kCellDepth = CellFormat::kDepth;
200   static const int kCellSize = CellFormat::kSize;
201 
Pack(PackedSideBlock<KernelSideFormat> * dst,int start_width)202   void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
203     std::uint8_t* dst_ptr = dst->current_data();
204     const std::uint8_t* src_ptr = this->complete_src_.data();
205     const int stride = this->complete_src_.stride();
206     // Load source WidthMajor data
207     v8i16 src_lines[kCells * 4];
208     for (int i = 0; i < kCells; i++) {
209 #define GEMMLOWP_UNROLLED_LOOP_ITER(k)                           \
210   src_lines[4 * i + k] =                                         \
211       __builtin_msa_ld_h(const_cast<std::uint8_t*>(src_ptr), 0); \
212   src_ptr += stride;
213 
214       GEMMLOWP_UNROLLED_LOOP_ITER(0)
215       GEMMLOWP_UNROLLED_LOOP_ITER(1)
216       GEMMLOWP_UNROLLED_LOOP_ITER(2)
217       GEMMLOWP_UNROLLED_LOOP_ITER(3)
218 
219 #undef GEMMLOWP_UNROLLED_LOOP_ITER
220     }
221     // Reorder the data within registers to make WidthMajor 4x2 cells
222     v8i16 src_lines_intertwined_2x[2 * kCells][2];
223     for (int i = 0; i < kCells; i++) {
224       src_lines_intertwined_2x[2 * i][0] =
225           __builtin_msa_ilvr_h(src_lines[4 * i + 2], src_lines[4 * i]);
226       src_lines_intertwined_2x[2 * i][1] =
227           __builtin_msa_ilvl_h(src_lines[4 * i + 2], src_lines[4 * i]);
228       src_lines_intertwined_2x[2 * i + 1][0] =
229           __builtin_msa_ilvr_h(src_lines[4 * i + 3], src_lines[4 * i + 1]);
230       src_lines_intertwined_2x[2 * i + 1][1] =
231           __builtin_msa_ilvl_h(src_lines[4 * i + 3], src_lines[4 * i + 1]);
232     }
233     v8i16 src_lines_intertwined_4x[2 * kCells][2];
234     for (int i = 0; i < kCells; i++) {
235       src_lines_intertwined_4x[2 * i][0] =
236           __builtin_msa_ilvr_h(src_lines_intertwined_2x[2 * i + 1][0],
237                                src_lines_intertwined_2x[2 * i][0]);
238       src_lines_intertwined_4x[2 * i][1] =
239           __builtin_msa_ilvl_h(src_lines_intertwined_2x[2 * i + 1][0],
240                                src_lines_intertwined_2x[2 * i][0]);
241       src_lines_intertwined_4x[2 * i + 1][0] =
242           __builtin_msa_ilvr_h(src_lines_intertwined_2x[2 * i + 1][1],
243                                src_lines_intertwined_2x[2 * i][1]);
244       src_lines_intertwined_4x[2 * i + 1][1] =
245           __builtin_msa_ilvl_h(src_lines_intertwined_2x[2 * i + 1][1],
246                                src_lines_intertwined_2x[2 * i][1]);
247     }
248     // Store the resulting WidthMajor 4x2 cells in the destination packed block
249     for (int outer = 0; outer < 2; outer++) {
250       for (int inner = 0; inner < 2; inner++) {
251         if (kCells % 2 == 0) {
252           for (int cell = 0; cell < kCells; cell += 2) {
253             v2i64 tmp = __builtin_msa_ilvr_d(
254                 reinterpret_cast<v2i64>(
255                     src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
256                 reinterpret_cast<v2i64>(
257                     src_lines_intertwined_4x[2 * cell + outer][inner]));
258             __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
259             dst_ptr += 16;
260           }
261           for (int cell = 0; cell < kCells; cell += 2) {
262             v2i64 tmp = __builtin_msa_ilvl_d(
263                 reinterpret_cast<v2i64>(
264                     src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
265                 reinterpret_cast<v2i64>(
266                     src_lines_intertwined_4x[2 * cell + outer][inner]));
267             __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
268             dst_ptr += 16;
269           }
270         } else {
271           // Store even number of low vector halves.
272           for (int cell = 0; cell < kCells - 1; cell += 2) {
273             v2i64 tmp = __builtin_msa_ilvr_d(
274                 reinterpret_cast<v2i64>(
275                     src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
276                 reinterpret_cast<v2i64>(
277                     src_lines_intertwined_4x[2 * cell + outer][inner]));
278             __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
279             dst_ptr += 16;
280           }
281           // Store last low half and first high half.
282           v2i64 tmp = reinterpret_cast<v2i64>(
283               src_lines_intertwined_4x[2 * 0 + outer][inner]);
284           tmp = __builtin_msa_insve_d(
285               tmp, 0,
286               reinterpret_cast<v2i64>(
287                   src_lines_intertwined_4x[2 * (kCells - 1) + outer][inner]));
288           __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
289           dst_ptr += 16;
290           // Store even number of high vector halves.
291           for (int cell = 1; cell < kCells; cell += 2) {
292             v2i64 tmp = __builtin_msa_ilvl_d(
293                 reinterpret_cast<v2i64>(
294                     src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
295                 reinterpret_cast<v2i64>(
296                     src_lines_intertwined_4x[2 * cell + outer][inner]));
297             __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
298             dst_ptr += 16;
299           }
300         }
301       }
302     }
303     // Compute sums across the depth dimension
304     v8i16 sums_of_2[kCells][4];
305     for (int outer = 0; outer < 2; outer++) {
306       for (int inner = 0; inner < 2; inner++) {
307         int i = 2 * outer + inner;
308         for (int cell = 0; cell < kCells; cell++) {
309           sums_of_2[cell][i] = reinterpret_cast<v8i16>(__builtin_msa_hadd_u_h(
310               reinterpret_cast<v16u8>(
311                   src_lines_intertwined_4x[2 * cell + outer][inner]),
312               reinterpret_cast<v16u8>(
313                   src_lines_intertwined_4x[2 * cell + outer][inner])));
314         }
315       }
316     }
317     v8i16 sums_of_4[kCells][2];
318     for (int i = 0; i < 2; i++) {
319       for (int cell = 0; cell < kCells; cell++) {
320         sums_of_4[cell][i] = __builtin_msa_addv_h(sums_of_2[cell][2 * i],
321                                                   sums_of_2[cell][2 * i + 1]);
322       }
323     }
324     v8i16 sums_of_8[kCells];
325     for (int cell = 0; cell < kCells; cell++) {
326       sums_of_8[cell] =
327           __builtin_msa_addv_h(sums_of_4[cell][0], sums_of_4[cell][1]);
328     }
329 
330     v4i32 sums_of_16[kCells];
331     const v8i16 zeroes = __builtin_msa_ldi_h(0);
332     for (int cell = 0; cell < kCells; cell++) {
333       sums_of_16[cell] = reinterpret_cast<v4i32>(
334           __builtin_msa_ilvr_h(zeroes, sums_of_8[cell]));
335       v8i16 tmp = __builtin_msa_ilvl_h(zeroes, sums_of_8[cell]);
336       sums_of_16[cell] =
337           __builtin_msa_addv_w(sums_of_16[cell], reinterpret_cast<v4i32>(tmp));
338     }
339     // Update the sums_of_each_slice vector
340     for (int cell = 0; cell < kCells; cell++) {
341       std::int32_t* sums_of_each_slice_ptr =
342           dst->sums_of_each_slice() + start_width + 4 * cell;
343       v4i32 tmp = __builtin_msa_ld_w(sums_of_each_slice_ptr, 0);
344       tmp = __builtin_msa_addv_w(tmp, sums_of_16[cell]);
345       __builtin_msa_st_w(tmp, sums_of_each_slice_ptr, 0);
346     }
347     dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
348   }
349 };
350 
351 template <int Width>
352 using Int8FastKernelFormat =
353     KernelSideFormatInt8<CellFormat<Width, 16, CellOrder::WidthMajor>, 1>;
354 
355 template <int Width>
356 class PackingRegisterBlock<WidthMajorUint8SideMap,
357                            PackedSideBlock<Int8FastKernelFormat<Width>>>
358     : public PackingRegisterBlockBase<
359           WidthMajorUint8SideMap,
360           PackedSideBlock<Int8FastKernelFormat<Width>>> {
361  public:
362   static_assert(Width == 2 || Width == 4, "");
363   typedef Int8FastKernelFormat<Width> KernelSideFormat;
364   typedef typename KernelSideFormat::Cell CellFormat;
365   static const int kCells = KernelSideFormat::kCells;
366   static const int kCellWidth = CellFormat::kWidth;
367   static const int kKernelWidth = CellFormat::kWidth * kCells;
368   static const int kCellDepth = CellFormat::kDepth;
369   static const int kCellSize = CellFormat::kSize;
370 
Pack(PackedSideBlock<KernelSideFormat> * dst,int start_width)371   void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
372     std::int32_t* sums_ptr = dst->sums_of_each_slice() + start_width;
373     std::uint8_t* dst_ptr = dst->current_data();
374     const std::uint8_t* const src_ptr = this->complete_src_.data();
375     const int stride = this->complete_src_.stride();
376     // Load source WidthMajor data.
377     v16i8 src_lines[Width];
378     for (int i = 0; i < Width; i++) {
379       src_lines[i] = __builtin_msa_ld_b(
380           const_cast<std::uint8_t*>(src_ptr + i * stride), 0);
381     }
382     for (int i = 0; i < Width; i++) {
383       // Subtract 128 by inverting bit 7.
384       src_lines[i] = reinterpret_cast<v16i8>(
385           __builtin_msa_bnegi_b(reinterpret_cast<v16u8>(src_lines[i]), 7));
386     }
387     for (int i = 0; i < Width; i++) {
388       __builtin_msa_st_b(src_lines[i], dst_ptr + 16 * i, 0);
389     }
390     v8i16 sums2[Width];
391     for (int i = 0; i < Width; i++) {
392       sums2[i] = __builtin_msa_hadd_s_h(src_lines[i], src_lines[i]);
393     }
394     v4i32 sums4_wide[Width];
395     for (int i = 0; i < Width; i++) {
396       sums4_wide[i] = __builtin_msa_hadd_s_w(sums2[i], sums2[i]);
397     }
398     v8i16 sums4[Width / 2];
399     for (int i = 0; i < Width / 2; i++) {
400       sums4[i] = __builtin_msa_pckev_h(
401           reinterpret_cast<v8i16>(sums4_wide[2 * i + 1]),
402           reinterpret_cast<v8i16>(sums4_wide[2 * i]));
403     }
404     v4i32 sums8_wide[Width / 2];
405     for (int i = 0; i < Width / 2; i++) {
406       sums8_wide[i] = __builtin_msa_hadd_s_w(sums4[i], sums4[i]);
407     }
408     if (Width == 4) {
409       v4i32 sum = __builtin_msa_ld_w(const_cast<std::int32_t*>(sums_ptr), 0);
410       v8i16 sums8 = __builtin_msa_pckev_h(
411           reinterpret_cast<v8i16>(sums8_wide[1]),
412           reinterpret_cast<v8i16>(sums8_wide[0]));
413       v4i32 sums16 = __builtin_msa_hadd_s_w(sums8, sums8);
414       sum = __builtin_msa_addv_w(sum, sums16);
415       __builtin_msa_st_w(sum, sums_ptr, 0);
416     } else {
417       assert(Width == 2);
418       std::int32_t sum[2] = { sums_ptr[0], sums_ptr[1] };
419       v2i64 sums16 = __builtin_msa_hadd_s_d(sums8_wide[0], sums8_wide[0]);
420       sum[0] += __builtin_msa_copy_s_w(reinterpret_cast<v4i32>(sums16), 0);
421       sum[1] += __builtin_msa_copy_s_w(reinterpret_cast<v4i32>(sums16), 2);
422       sums_ptr[0] = sum[0];
423       sums_ptr[1] = sum[1];
424     }
425     dst->seek_forward_n_cells(1);
426   }
427 };
428 
429 }  // namespace gemmlowp
430 
431 #endif  // GEMMLOWP_INTERNAL_PACK_MSA_H_
432