1 // Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // output_msa.h: optimized MSA specializations of the templates in output.h.
16 
17 #ifndef GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
18 #define GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
19 
20 #include "output.h"
21 
22 #include <msa.h>
23 
24 namespace gemmlowp {
25 
26 template <>
27 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
28                                  RegBufferInt32<4>> {
29   typedef RegBufferInt32<4> InputType;
30   typedef RegBufferUint8<4> OutputType;
31 
32   typedef OutputStageSaturatingCastToUint8 OutputStage;
33 
34   OutputStageEvalBufferImpl(const OutputStage&) {}
35 
36   OutputType Eval(InputType input) const {
37     OutputType output;
38     // Signed saturate each 32-bit element to 9 bits
39     // (this takes full care of non-negative elements).
40     v4i32 tmp = __builtin_msa_sat_s_w(input.reg[0], 8);
41     // Pack every 32-bit element into 16 bits.
42     tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
43         reinterpret_cast<v8i16>(tmp), reinterpret_cast<v8i16>(tmp)));
44     // Detect negative elements with arithmetic shift right (we
45     // get a 16-bit mask of all zeroes or all ones for every element).
46     v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp), 15);
47     // Zero out negative elements.
48     signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(
49         reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp), 0));
50     // Pack every element into 8 bits.
51     tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
52         reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs)));
53     // Return 4 uint8_t elements as uint32_t.
54     output.reg[0] = __builtin_msa_copy_s_w(tmp, 0);
55     return output;
56   }
57 };
58 
59 template <>
60 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
61                                  RegBufferInt32<8>> {
62   typedef RegBufferInt32<8> InputType;
63   typedef RegBufferUint8<8> OutputType;
64 
65   typedef OutputStageSaturatingCastToUint8 OutputStage;
66 
67   OutputStageEvalBufferImpl(const OutputStage&) {}
68 
69   OutputType Eval(InputType input) const {
70     OutputType output;
71     // Signed saturate each 32-bit element to 9 bits
72     // (this takes full care of non-negative elements).
73     v4i32 tmp_lo = __builtin_msa_sat_s_w(input.reg[0], 8);
74     v4i32 tmp_hi = __builtin_msa_sat_s_w(input.reg[1], 8);
75     // Pack every 32-bit element into 16 bits,
76     // combining all 8 elements into one vector.
77     tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
78         reinterpret_cast<v8i16>(tmp_hi), reinterpret_cast<v8i16>(tmp_lo)));
79     // Detect negative elements with arithmetic shift right (we
80     // get a 16-bit mask of all zeroes or all ones for every element).
81     v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp_lo), 15);
82     // Zero out negative elements.
83     signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(
84         reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp_lo), 0));
85     // Pack every element into 8 bits.
86     tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
87         reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs)));
88     // Return 8 uint8_t elements as 2 uint32_t's.
89     output.reg[0] = __builtin_msa_copy_s_w(tmp_lo, 0);
90     output.reg[1] = __builtin_msa_copy_s_w(tmp_lo, 1);
91     return output;
92   }
93 };
94 
95 #define GEMMLOWP_MIPS_SAT_U8_16(out, in0, in1, in2, in3)                     \
96   {                                                                          \
97     v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 8);                              \
98     v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 8);                              \
99     v4i32 tmp2 = __builtin_msa_sat_s_w(in2, 8);                              \
100     v4i32 tmp3 = __builtin_msa_sat_s_w(in3, 8);                              \
101     tmp0 = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(                    \
102         reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0)));      \
103     tmp2 = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(                    \
104         reinterpret_cast<v8i16>(tmp3), reinterpret_cast<v8i16>(tmp2)));      \
105     v8i16 signs0 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp0), 15);  \
106     v8i16 signs1 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp2), 15);  \
107     signs0 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(                  \
108         reinterpret_cast<v16u8>(signs0), reinterpret_cast<v16u8>(tmp0), 0)); \
109     signs1 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(                  \
110         reinterpret_cast<v16u8>(signs1), reinterpret_cast<v16u8>(tmp2), 0)); \
111     signs0 = reinterpret_cast<v8i16>(__builtin_msa_pckev_b(                  \
112         reinterpret_cast<v16i8>(signs1), reinterpret_cast<v16i8>(signs0)));  \
113     out = reinterpret_cast<v16i8>(signs0);                                   \
114   }
115 
116 template <>
117 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
118                                  RegBufferInt32<16>> {
119   typedef RegBufferInt32<16> InputType;
120   typedef RegBufferUint8<16> OutputType;
121 
122   typedef OutputStageSaturatingCastToUint8 OutputStage;
123 
124   OutputStageEvalBufferImpl(const OutputStage&) {}
125 
126   OutputType Eval(InputType input) const {
127     OutputType output;
128     GEMMLOWP_MIPS_SAT_U8_16(output.reg[0], input.reg[0], input.reg[1],
129                             input.reg[2], input.reg[3]);
130     return output;
131   }
132 };
133 
134 template <>
135 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
136                                  RegBufferInt32<32>> {
137   typedef RegBufferInt32<32> InputType;
138   typedef RegBufferUint8<32> OutputType;
139 
140   typedef OutputStageSaturatingCastToUint8 OutputStage;
141 
142   OutputStageEvalBufferImpl(const OutputStage&) {}
143 
144   OutputType Eval(InputType input) const {
145     OutputType output;
146     GEMMLOWP_MIPS_SAT_U8_16(output.reg[0], input.reg[0], input.reg[1],
147                             input.reg[2], input.reg[3]);
148     GEMMLOWP_MIPS_SAT_U8_16(output.reg[1], input.reg[4], input.reg[5],
149                             input.reg[6], input.reg[7]);
150     return output;
151   }
152 };
153 
154 #undef GEMMLOWP_MIPS_SAT_U8_16
155 
156 template <>
157 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
158                                  RegBufferInt32<4>> {
159   typedef RegBufferInt32<4> InputType;
160   typedef RegBufferInt16<4> OutputType;
161 
162   typedef OutputStageSaturatingCastToInt16 OutputStage;
163 
164   OutputStageEvalBufferImpl(const OutputStage&) {}
165 
166   OutputType Eval(InputType input) const {
167     OutputType output;
168     // Signed saturate each 32-bit element to 16 bits.
169     v8i16 tmp = reinterpret_cast<v8i16>(__builtin_msa_sat_s_w(
170         input.reg[0], 15));
171     output.reg[0] = __builtin_msa_copy_s_h(tmp, 0);
172     output.reg[1] = __builtin_msa_copy_s_h(tmp, 2);
173     output.reg[2] = __builtin_msa_copy_s_h(tmp, 4);
174     output.reg[3] = __builtin_msa_copy_s_h(tmp, 6);
175     return output;
176   }
177 };
178 
179 #define GEMMLOWP_MIPS_SAT_I16_8(out, in0, in1)                         \
180   {                                                                    \
181     v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 15);                       \
182     v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 15);                       \
183     out = __builtin_msa_pckev_h(                                       \
184         reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0)); \
185   }
186 
187 template <>
188 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
189                                  RegBufferInt32<8>> {
190   typedef RegBufferInt32<8> InputType;
191   typedef RegBufferInt16<8> OutputType;
192 
193   typedef OutputStageSaturatingCastToInt16 OutputStage;
194 
195   OutputStageEvalBufferImpl(const OutputStage&) {}
196 
197   OutputType Eval(InputType input) const {
198     OutputType output;
199     GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
200     return output;
201   }
202 };
203 
204 template <>
205 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
206                                  RegBufferInt32<16>> {
207   typedef RegBufferInt32<16> InputType;
208   typedef RegBufferInt16<16> OutputType;
209 
210   typedef OutputStageSaturatingCastToInt16 OutputStage;
211 
212   OutputStageEvalBufferImpl(const OutputStage&) {}
213 
214   OutputType Eval(InputType input) const {
215     OutputType output;
216     GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
217     GEMMLOWP_MIPS_SAT_I16_8(output.reg[1], input.reg[2], input.reg[3]);
218     return output;
219   }
220 };
221 
222 template <>
223 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
224                                  RegBufferInt32<32>> {
225   typedef RegBufferInt32<32> InputType;
226   typedef RegBufferInt16<32> OutputType;
227 
228   typedef OutputStageSaturatingCastToInt16 OutputStage;
229 
230   OutputStageEvalBufferImpl(const OutputStage&) {}
231 
232   OutputType Eval(InputType input) const {
233     OutputType output;
234     GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
235     GEMMLOWP_MIPS_SAT_I16_8(output.reg[1], input.reg[2], input.reg[3]);
236     GEMMLOWP_MIPS_SAT_I16_8(output.reg[2], input.reg[4], input.reg[5]);
237     GEMMLOWP_MIPS_SAT_I16_8(output.reg[3], input.reg[6], input.reg[7]);
238     return output;
239   }
240 };
241 
242 #undef GEMMLOWP_MIPS_SAT_I16_8
243 
244 template <typename DstType>
245 struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
246   static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
247                   int col) {
248     if (DstType::kOrder == MapOrder::ColMajor) {
249       StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
250     } else {
251       *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
252       *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
253       *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
254       *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
255     }
256   }
257 };
258 
259 template <typename DstType>
260 struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
261   static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
262                   int col) {
263     if (DstType::kOrder == MapOrder::ColMajor) {
264       StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
265       StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
266     } else {
267       *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
268       *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
269       *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
270       *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
271       *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]);
272       *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]);
273       *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]);
274       *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]);
275     }
276   }
277 };
278 
279 template <typename DstType>
280 struct StoreFinalOutputImpl<RegBlockInt16<4, 1>, DstType> {
281   static void Run(const RegBlockInt16<4, 1>& src, DstType* dst, int row,
282                   int col) {
283     *dst->data(row + 0, col) = src.buf.reg[0];
284     *dst->data(row + 1, col) = src.buf.reg[1];
285     *dst->data(row + 2, col) = src.buf.reg[2];
286     *dst->data(row + 3, col) = src.buf.reg[3];
287   }
288 };
289 
290 template <typename DstType>
291 struct StoreFinalOutputImpl<RegBlockInt16<8, 1>, DstType> {
292   static void Run(const RegBlockInt16<8, 1>& src, DstType* dst, int row,
293                   int col) {
294     if (DstType::kOrder == MapOrder::ColMajor) {
295       StoreInt16x8(dst->data(row, col), src.buf.reg[0]);
296     } else {
297       *dst->data(row + 0, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 0);
298       *dst->data(row + 1, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 1);
299       *dst->data(row + 2, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 2);
300       *dst->data(row + 3, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 3);
301       *dst->data(row + 4, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 4);
302       *dst->data(row + 5, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 5);
303       *dst->data(row + 6, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 6);
304       *dst->data(row + 7, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 7);
305     }
306   }
307 };
308 
309 inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
310   RegBlockInt32<4, 4> result;
311   v4i32 tmp0, tmp1;
312   tmp0 = __builtin_msa_ilvr_w(src.buf.reg[1], src.buf.reg[0]);
313   tmp1 = __builtin_msa_ilvr_w(src.buf.reg[3], src.buf.reg[2]);
314   result.buf.reg[0] = reinterpret_cast<v4i32>(__builtin_msa_ilvr_d(
315       reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
316   result.buf.reg[1] = reinterpret_cast<v4i32>(__builtin_msa_ilvl_d(
317       reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
318   tmp0 = __builtin_msa_ilvl_w(src.buf.reg[1], src.buf.reg[0]);
319   tmp1 = __builtin_msa_ilvl_w(src.buf.reg[3], src.buf.reg[2]);
320   result.buf.reg[2] = reinterpret_cast<v4i32>(__builtin_msa_ilvr_d(
321       reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
322   result.buf.reg[3] = reinterpret_cast<v4i32>(__builtin_msa_ilvl_d(
323       reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
324   return result;
325 }
326 
327 template <typename DstType>
328 struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
329   static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
330                   int col) {
331     if (DstType::kOrder == MapOrder::ColMajor) {
332       for (int i = 0; i < 4; i++) {
333         StoreInt32x4(dst->data(row, col + i), src.buf.reg[i]);
334       }
335     } else {
336       const auto transpose = Transpose(src);
337       for (int i = 0; i < 4; i++) {
338         StoreInt32x4(dst->data(row + i, col), transpose.buf.reg[i]);
339       }
340     }
341   }
342 };
343 
344 template <typename DstType>
345 struct StoreFinalOutputImpl<RegBlockInt16<4, 4>, DstType> {
346   static void Run(const RegBlockInt16<4, 4>& src, DstType* dst, int row,
347                   int col) {
348     std::int16_t buf[16];
349     StoreInt16x8(buf + 0, src.buf.reg[0]);
350     StoreInt16x8(buf + 8, src.buf.reg[1]);
351     for (int i = 0; i < 4; i++) {
352       for (int j = 0; j < 4; j++) {
353         *dst->data(row + i, col + j) = buf[i + 4 * j];
354       }
355     }
356   }
357 };
358 
359 template <typename DstType>
360 struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
361   static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
362                   int col) {
363     if (DstType::kOrder == MapOrder::ColMajor) {
364       for (int i = 0; i < 4; i++) {
365         StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
366         StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
367       }
368     } else {
369       RegBlockInt32<4, 4> top;
370       top.buf.reg[0] = src.buf.reg[0];
371       top.buf.reg[1] = src.buf.reg[2];
372       top.buf.reg[2] = src.buf.reg[4];
373       top.buf.reg[3] = src.buf.reg[6];
374       const auto transpose_top = Transpose(top);
375       for (int i = 0; i < 4; i++) {
376         StoreInt32x4(dst->data(row + i, col), transpose_top.buf.reg[i]);
377       }
378       RegBlockInt32<4, 4> bottom;
379       bottom.buf.reg[0] = src.buf.reg[1];
380       bottom.buf.reg[1] = src.buf.reg[3];
381       bottom.buf.reg[2] = src.buf.reg[5];
382       bottom.buf.reg[3] = src.buf.reg[7];
383       const auto transpose_bottom = Transpose(bottom);
384       for (int i = 0; i < 4; i++) {
385         StoreInt32x4(dst->data(row + 4 + i, col), transpose_bottom.buf.reg[i]);
386       }
387     }
388   }
389 };
390 
391 template <typename DstType>
392 struct StoreFinalOutputImpl<RegBlockInt16<8, 4>, DstType> {
393   static void Run(const RegBlockInt16<8, 4>& src, DstType* dst, int row,
394                   int col) {
395     if (DstType::kOrder == MapOrder::ColMajor) {
396       for (int i = 0; i < 4; i++) {
397         StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
398       }
399     } else {
400       std::int16_t buf[32];
401       StoreInt16x8(buf + 0, src.buf.reg[0]);
402       StoreInt16x8(buf + 8, src.buf.reg[1]);
403       StoreInt16x8(buf + 16, src.buf.reg[2]);
404       StoreInt16x8(buf + 24, src.buf.reg[3]);
405       for (int i = 0; i < 8; i++) {
406         for (int j = 0; j < 4; j++) {
407           *dst->data(row + i, col + j) = buf[i + 8 * j];
408         }
409       }
410     }
411   }
412 };
413 
414 template <typename DstType>
415 struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
416   static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
417                   int col) {
418     if (DstType::kOrder == MapOrder::ColMajor) {
419       for (int i = 0; i < 8; i++) {
420         StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
421         StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
422       }
423     } else {
424       RegBlockInt32<4, 4> top_left;
425       top_left.buf.reg[0] = src.buf.reg[0];
426       top_left.buf.reg[1] = src.buf.reg[2];
427       top_left.buf.reg[2] = src.buf.reg[4];
428       top_left.buf.reg[3] = src.buf.reg[6];
429       const auto transpose_top_left = Transpose(top_left);
430       for (int i = 0; i < 4; i++) {
431         StoreInt32x4(dst->data(row + i, col), transpose_top_left.buf.reg[i]);
432       }
433       RegBlockInt32<4, 4> bottom_left;
434       bottom_left.buf.reg[0] = src.buf.reg[1];
435       bottom_left.buf.reg[1] = src.buf.reg[3];
436       bottom_left.buf.reg[2] = src.buf.reg[5];
437       bottom_left.buf.reg[3] = src.buf.reg[7];
438       const auto transpose_bottom_left = Transpose(bottom_left);
439       for (int i = 0; i < 4; i++) {
440         StoreInt32x4(dst->data(row + 4 + i, col),
441                      transpose_bottom_left.buf.reg[i]);
442       }
443       RegBlockInt32<4, 4> top_right;
444       top_right.buf.reg[0] = src.buf.reg[8];
445       top_right.buf.reg[1] = src.buf.reg[10];
446       top_right.buf.reg[2] = src.buf.reg[12];
447       top_right.buf.reg[3] = src.buf.reg[14];
448       const auto transpose_top_right = Transpose(top_right);
449       for (int i = 0; i < 4; i++) {
450         StoreInt32x4(dst->data(row + i, col + 4),
451                      transpose_top_right.buf.reg[i]);
452       }
453       RegBlockInt32<4, 4> bottom_right;
454       bottom_right.buf.reg[0] = src.buf.reg[9];
455       bottom_right.buf.reg[1] = src.buf.reg[11];
456       bottom_right.buf.reg[2] = src.buf.reg[13];
457       bottom_right.buf.reg[3] = src.buf.reg[15];
458       const auto transpose_bottom_right = Transpose(bottom_right);
459       for (int i = 0; i < 4; i++) {
460         StoreInt32x4(dst->data(row + 4 + i, col + 4),
461                      transpose_bottom_right.buf.reg[i]);
462       }
463     }
464   }
465 };
466 
467 template <typename DstType>
468 struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
469   static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
470                   int col) {
471     if (DstType::kOrder == MapOrder::ColMajor) {
472       for (int i = 0; i < 8; i++) {
473         StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
474       }
475     } else {
476       // top-left 4x4
477       v4i32 t0 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[1],
478           src.buf.reg[0]));
479       v4i32 t1 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[3],
480           src.buf.reg[2]));
481       v2i64 u0 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t1, t0));
482       v2i64 u1 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t1, t0));
483       // top-right 4x4
484       v4i32 t2 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[5],
485           src.buf.reg[4]));
486       v4i32 t3 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[7],
487           src.buf.reg[6]));
488       v2i64 u2 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t3, t2));
489       v2i64 u3 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t3, t2));
490       // bottom-left 4x4
491       v4i32 t4 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[1],
492           src.buf.reg[0]));
493       v4i32 t5 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[3],
494           src.buf.reg[2]));
495       v2i64 u4 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t5, t4));
496       v2i64 u5 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t5, t4));
497       // bottom-right 4x4
498       v4i32 t6 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[5],
499           src.buf.reg[4]));
500       v4i32 t7 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[7],
501           src.buf.reg[6]));
502       v2i64 u6 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t7, t6));
503       v2i64 u7 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t7, t6));
504 
505       StoreInt16x8(dst->data(row + 0, col), reinterpret_cast<v8i16>(
506           __builtin_msa_ilvr_d(u2, u0)));
507       StoreInt16x8(dst->data(row + 1, col), reinterpret_cast<v8i16>(
508           __builtin_msa_ilvl_d(u2, u0)));
509       StoreInt16x8(dst->data(row + 2, col), reinterpret_cast<v8i16>(
510           __builtin_msa_ilvr_d(u3, u1)));
511       StoreInt16x8(dst->data(row + 3, col), reinterpret_cast<v8i16>(
512           __builtin_msa_ilvl_d(u3, u1)));
513       StoreInt16x8(dst->data(row + 4, col), reinterpret_cast<v8i16>(
514           __builtin_msa_ilvr_d(u6, u4)));
515       StoreInt16x8(dst->data(row + 5, col), reinterpret_cast<v8i16>(
516           __builtin_msa_ilvl_d(u6, u4)));
517       StoreInt16x8(dst->data(row + 6, col), reinterpret_cast<v8i16>(
518           __builtin_msa_ilvr_d(u7, u5)));
519       StoreInt16x8(dst->data(row + 7, col), reinterpret_cast<v8i16>(
520           __builtin_msa_ilvl_d(u7, u5)));
521     }
522   }
523 };
524 
525 template <typename DstType>
526 struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
527   static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
528                   int col) {
529     if (DstType::kOrder == MapOrder::ColMajor) {
530       *dst->data(row, col + 0) = GetLane<0>(src.buf.reg[0]);
531       *dst->data(row, col + 1) = GetLane<1>(src.buf.reg[0]);
532       *dst->data(row, col + 2) = GetLane<2>(src.buf.reg[0]);
533       *dst->data(row, col + 3) = GetLane<3>(src.buf.reg[0]);
534     } else {
535       StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
536     }
537   }
538 };
539 
540 template <typename DstType>
541 struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
542   static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
543                   int col) {
544     const std::uint32_t src_reg = src.buf.reg[0];
545     for (int i = 0; i < 4; i++) {
546       *dst->data(row + i, col) = (src_reg >> (8 * i));
547     }
548   }
549 };
550 
551 template <typename DstType>
552 struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
553   static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
554                   int col) {
555     for (int i = 0; i < 4; i++) {
556       *dst->data(row + i, col) = (src.buf.reg[0] >> (8 * i));
557     }
558     for (int i = 0; i < 4; i++) {
559       *dst->data(row + 4 + i, col) = (src.buf.reg[1] >> (8 * i));
560     }
561   }
562 };
563 
564 template <typename DstType>
565 struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
566   static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
567                   int col) {
568     for (int i = 0; i < 4; i++) {
569       *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
570     }
571   }
572 };
573 
574 template <typename DstType>
575 struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
576   static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
577                   int col) {
578     std::uint8_t buf[16];
579     StoreUint8x16(buf, src.buf.reg[0]);
580     for (int c = 0; c < 4; c++) {
581       for (int r = 0; r < 4; r++) {
582         *dst->data(row + r, col + c) = buf[r + 4 * c];
583       }
584     }
585   }
586 };
587 
588 template <typename DstType>
589 struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
590   static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
591                   int col) {
592     std::uint8_t buf[32];
593     StoreUint8x16(buf, src.buf.reg[0]);
594     StoreUint8x16(buf + 16, src.buf.reg[1]);
595     for (int c = 0; c < 4; c++) {
596       for (int r = 0; r < 8; r++) {
597         *dst->data(row + r, col + c) = buf[r + 8 * c];
598       }
599     }
600   }
601 };
602 
603 template <typename DstType>
604 struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
605   static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
606                   int col) {
607     std::uint8_t buf[64];
608     StoreUint8x16(buf, src.buf.reg[0]);
609     StoreUint8x16(buf + 16, src.buf.reg[1]);
610     StoreUint8x16(buf + 32, src.buf.reg[2]);
611     StoreUint8x16(buf + 48, src.buf.reg[3]);
612     for (int c = 0; c < 8; c++) {
613       for (int r = 0; r < 8; r++) {
614         *dst->data(row + r, col + c) = buf[r + 8 * c];
615       }
616     }
617   }
618 };
619 
620 }  // namespace gemmlowp
621 
622 #endif  // GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
623