1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // output_neon.h: optimized NEON specializations of the templates in output.h.
16 
17 #ifndef GEMMLOWP_INTERNAL_OUTPUT_NEON_H_
18 #define GEMMLOWP_INTERNAL_OUTPUT_NEON_H_
19 
20 #include "output.h"
21 
22 #include <arm_neon.h>
23 
24 namespace gemmlowp {
25 
26 template <>
27 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
28                                  RegBufferInt32<4>> {
29   typedef RegBufferInt32<4> InputType;
30   typedef RegBufferUint8<4> OutputType;
31 
32   typedef OutputStageSaturatingCastToUint8 OutputStage;
33 
34   OutputStageEvalBufferImpl(const OutputStage&) {}
35 
36   OutputType Eval(InputType input) const {
37     OutputType output;
38     int16x4_t res_16 = vqmovn_s32(input.reg[0]);
39     uint8x8_t res_8 = vqmovun_s16(vcombine_s16(res_16, res_16));
40     output.reg[0] = vget_lane_u32(vreinterpret_u32_u8(res_8), 0);
41     return output;
42   }
43 };
44 
45 template <>
46 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
47                                  RegBufferInt32<8>> {
48   typedef RegBufferInt32<8> InputType;
49   typedef RegBufferUint8<8> OutputType;
50 
51   typedef OutputStageSaturatingCastToUint8 OutputStage;
52 
53   OutputStageEvalBufferImpl(const OutputStage&) {}
54 
55   OutputType Eval(InputType input) const {
56     OutputType output;
57     int16x8_t res_16 =
58         vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
59     output.reg[0] = vqmovun_s16(res_16);
60     return output;
61   }
62 };
63 
64 template <>
65 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
66                                  RegBufferInt32<16>> {
67   typedef RegBufferInt32<16> InputType;
68   typedef RegBufferUint8<16> OutputType;
69 
70   typedef OutputStageSaturatingCastToUint8 OutputStage;
71 
72   OutputStageEvalBufferImpl(const OutputStage&) {}
73 
74   OutputType Eval(InputType input) const {
75     OutputType output;
76     int16x8_t res_16_0 =
77         vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
78     int16x8_t res_16_1 =
79         vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
80     output.reg[0] = vqmovun_s16(res_16_0);
81     output.reg[1] = vqmovun_s16(res_16_1);
82     return output;
83   }
84 };
85 
86 template <>
87 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
88                                  RegBufferInt32<32>> {
89   typedef RegBufferInt32<32> InputType;
90   typedef RegBufferUint8<32> OutputType;
91 
92   typedef OutputStageSaturatingCastToUint8 OutputStage;
93 
94   OutputStageEvalBufferImpl(const OutputStage&) {}
95 
96   OutputType Eval(InputType input) const {
97     OutputType output;
98     int16x8_t res_16[4];
99     for (int i = 0; i < 4; i++) {
100       res_16[i] = vcombine_s16(vqmovn_s32(input.reg[2 * i]),
101                                vqmovn_s32(input.reg[2 * i + 1]));
102     }
103     for (int i = 0; i < 4; i++) {
104       output.reg[i] = vqmovun_s16(res_16[i]);
105     }
106     return output;
107   }
108 };
109 
110 template <>
111 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
112                                  RegBufferInt32<4>> {
113   typedef RegBufferInt32<4> InputType;
114   typedef RegBufferInt8<4> OutputType;
115 
116   typedef OutputStageSaturatingCastToInt8 OutputStage;
117 
118   OutputStageEvalBufferImpl(const OutputStage&) {}
119 
120   OutputType Eval(InputType input) const {
121     OutputType output;
122     int16x4_t res_16 = vqmovn_s32(input.reg[0]);
123     int8x8_t res_8 = vqmovn_s16(vcombine_s16(res_16, res_16));
124     output.reg[0] = vget_lane_s32(vreinterpret_s32_s8(res_8), 0);
125     return output;
126   }
127 };
128 
129 template <>
130 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
131                                  RegBufferInt32<8>> {
132   typedef RegBufferInt32<8> InputType;
133   typedef RegBufferInt8<8> OutputType;
134 
135   typedef OutputStageSaturatingCastToInt8 OutputStage;
136 
137   OutputStageEvalBufferImpl(const OutputStage&) {}
138 
139   OutputType Eval(InputType input) const {
140     OutputType output;
141     int16x8_t res_16 =
142         vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
143     output.reg[0] = vqmovn_s16(res_16);
144     return output;
145   }
146 };
147 
148 template <>
149 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
150                                  RegBufferInt32<16>> {
151   typedef RegBufferInt32<16> InputType;
152   typedef RegBufferInt8<16> OutputType;
153 
154   typedef OutputStageSaturatingCastToInt8 OutputStage;
155 
156   OutputStageEvalBufferImpl(const OutputStage&) {}
157 
158   OutputType Eval(InputType input) const {
159     OutputType output;
160     int16x8_t res_16_0 =
161         vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
162     int16x8_t res_16_1 =
163         vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
164     output.reg[0] = vqmovn_s16(res_16_0);
165     output.reg[1] = vqmovn_s16(res_16_1);
166     return output;
167   }
168 };
169 
170 template <>
171 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
172                                  RegBufferInt32<32>> {
173   typedef RegBufferInt32<32> InputType;
174   typedef RegBufferInt8<32> OutputType;
175 
176   typedef OutputStageSaturatingCastToInt8 OutputStage;
177 
178   OutputStageEvalBufferImpl(const OutputStage&) {}
179 
180   OutputType Eval(InputType input) const {
181     OutputType output;
182     int16x8_t res_16[4];
183     for (int i = 0; i < 4; i++) {
184       res_16[i] = vcombine_s16(vqmovn_s32(input.reg[2 * i]),
185                                vqmovn_s32(input.reg[2 * i + 1]));
186     }
187     for (int i = 0; i < 4; i++) {
188       output.reg[i] = vqmovn_s16(res_16[i]);
189     }
190     return output;
191   }
192 };
193 
194 template <>
195 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
196                                  RegBufferInt32<4>> {
197   typedef RegBufferInt32<4> InputType;
198   typedef RegBufferInt16<4> OutputType;
199 
200   typedef OutputStageSaturatingCastToInt16 OutputStage;
201 
202   OutputStageEvalBufferImpl(const OutputStage&) {}
203 
204   OutputType Eval(InputType input) const {
205     OutputType output;
206     output.reg[0] = vqmovn_s32(input.reg[0]);
207     return output;
208   }
209 };
210 
211 template <>
212 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
213                                  RegBufferInt32<8>> {
214   typedef RegBufferInt32<8> InputType;
215   typedef RegBufferInt16<8> OutputType;
216 
217   typedef OutputStageSaturatingCastToInt16 OutputStage;
218 
219   OutputStageEvalBufferImpl(const OutputStage&) {}
220 
221   OutputType Eval(InputType input) const {
222     OutputType output;
223     output.reg[0] =
224         vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
225     return output;
226   }
227 };
228 
229 template <>
230 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
231                                  RegBufferInt32<16>> {
232   typedef RegBufferInt32<16> InputType;
233   typedef RegBufferInt16<16> OutputType;
234 
235   typedef OutputStageSaturatingCastToInt16 OutputStage;
236 
237   OutputStageEvalBufferImpl(const OutputStage&) {}
238 
239   OutputType Eval(InputType input) const {
240     OutputType output;
241     output.reg[0] =
242         vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
243     output.reg[1] =
244         vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
245     return output;
246   }
247 };
248 
249 template <>
250 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
251                                  RegBufferInt32<32>> {
252   typedef RegBufferInt32<32> InputType;
253   typedef RegBufferInt16<32> OutputType;
254 
255   typedef OutputStageSaturatingCastToInt16 OutputStage;
256 
257   OutputStageEvalBufferImpl(const OutputStage&) {}
258 
259   OutputType Eval(InputType input) const {
260     OutputType output;
261     output.reg[0] =
262         vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
263     output.reg[1] =
264         vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
265     output.reg[2] =
266         vcombine_s16(vqmovn_s32(input.reg[4]), vqmovn_s32(input.reg[5]));
267     output.reg[3] =
268         vcombine_s16(vqmovn_s32(input.reg[6]), vqmovn_s32(input.reg[7]));
269     return output;
270   }
271 };
272 
273 template <typename DstType>
274 struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
275   static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
276                   int col) {
277     if (DstType::kOrder == MapOrder::ColMajor) {
278       StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
279       StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
280     } else {
281       vst1q_lane_s32(dst->data(row + 0, col), src.buf.reg[0], 0);
282       vst1q_lane_s32(dst->data(row + 1, col), src.buf.reg[0], 1);
283       vst1q_lane_s32(dst->data(row + 2, col), src.buf.reg[0], 2);
284       vst1q_lane_s32(dst->data(row + 3, col), src.buf.reg[0], 3);
285       vst1q_lane_s32(dst->data(row + 4, col), src.buf.reg[1], 0);
286       vst1q_lane_s32(dst->data(row + 5, col), src.buf.reg[1], 1);
287       vst1q_lane_s32(dst->data(row + 6, col), src.buf.reg[1], 2);
288       vst1q_lane_s32(dst->data(row + 7, col), src.buf.reg[1], 3);
289     }
290   }
291 };
292 
293 template <typename DstType>
294 struct StoreFinalOutputImpl<RegBlockInt16<4, 1>, DstType> {
295   static void Run(const RegBlockInt16<4, 1>& src, DstType* dst, int row,
296                   int col) {
297     if (DstType::kOrder == MapOrder::ColMajor) {
298       StoreInt16x4(dst->data(row, col), src.buf.reg[0]);
299     } else {
300       vst1_lane_s16(dst->data(row + 0, col), src.buf.reg[0], 0);
301       vst1_lane_s16(dst->data(row + 1, col), src.buf.reg[0], 1);
302       vst1_lane_s16(dst->data(row + 2, col), src.buf.reg[0], 2);
303       vst1_lane_s16(dst->data(row + 3, col), src.buf.reg[0], 3);
304     }
305   }
306 };
307 
308 template <typename DstType>
309 struct StoreFinalOutputImpl<RegBlockInt16<8, 1>, DstType> {
310   static void Run(const RegBlockInt16<8, 1>& src, DstType* dst, int row,
311                   int col) {
312     if (DstType::kOrder == MapOrder::ColMajor) {
313       StoreInt16x8(dst->data(row, col), src.buf.reg[0]);
314     } else {
315       vst1q_lane_s16(dst->data(row + 0, col), src.buf.reg[0], 0);
316       vst1q_lane_s16(dst->data(row + 1, col), src.buf.reg[0], 1);
317       vst1q_lane_s16(dst->data(row + 2, col), src.buf.reg[0], 2);
318       vst1q_lane_s16(dst->data(row + 3, col), src.buf.reg[0], 3);
319       vst1q_lane_s16(dst->data(row + 4, col), src.buf.reg[0], 4);
320       vst1q_lane_s16(dst->data(row + 5, col), src.buf.reg[0], 5);
321       vst1q_lane_s16(dst->data(row + 6, col), src.buf.reg[0], 6);
322       vst1q_lane_s16(dst->data(row + 7, col), src.buf.reg[0], 7);
323     }
324   }
325 };
326 
327 inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
328   const int32x4x2_t t0 = vtrnq_s32(src.buf.reg[0], src.buf.reg[1]);
329   const int32x4x2_t t1 = vtrnq_s32(src.buf.reg[2], src.buf.reg[3]);
330   RegBlockInt32<4, 4> result;
331   result.buf.reg[0] =
332       vcombine_s32(vget_low_s32(t0.val[0]), vget_low_s32(t1.val[0]));
333   result.buf.reg[1] =
334       vcombine_s32(vget_low_s32(t0.val[1]), vget_low_s32(t1.val[1]));
335   result.buf.reg[2] =
336       vcombine_s32(vget_high_s32(t0.val[0]), vget_high_s32(t1.val[0]));
337   result.buf.reg[3] =
338       vcombine_s32(vget_high_s32(t0.val[1]), vget_high_s32(t1.val[1]));
339   return result;
340 }
341 
342 template <typename DstType>
343 struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
344   static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
345                   int col) {
346     const auto& block =
347         DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
348     std::int32_t* dst_ptr = dst->data(row, col);
349     int stride = dst->stride();
350     for (int i = 0; i < 4; i++) {
351       vst1q_s32(dst_ptr + i * stride, block.buf.reg[i]);
352     }
353   }
354 };
355 
356 template <typename DstType>
357 struct StoreFinalOutputImpl<RegBlockInt16<4, 4>, DstType> {
358   static void Run(const RegBlockInt16<4, 4>& src, DstType* dst, int row,
359                   int col) {
360     if (DstType::kOrder == MapOrder::ColMajor) {
361       vst1_s16(dst->data(row, col + 0), vget_low_s16(src.buf.reg[0]));
362       vst1_s16(dst->data(row, col + 1), vget_high_s16(src.buf.reg[0]));
363       vst1_s16(dst->data(row, col + 2), vget_low_s16(src.buf.reg[1]));
364       vst1_s16(dst->data(row, col + 3), vget_high_s16(src.buf.reg[1]));
365     } else {
366       const int16x4x2_t t0 =
367           vtrn_s16(vget_low_s16(src.buf.reg[0]), vget_high_s16(src.buf.reg[0]));
368       const int16x4x2_t t1 =
369           vtrn_s16(vget_low_s16(src.buf.reg[1]), vget_high_s16(src.buf.reg[1]));
370       const int32x4x2_t t =
371           vtrnq_s32(vreinterpretq_s32_s16(vcombine_s16(t0.val[0], t0.val[1])),
372                     vreinterpretq_s32_s16(vcombine_s16(t1.val[0], t1.val[1])));
373       vst1_s16(dst->data(row + 0, col),
374                vget_low_s16(vreinterpretq_s16_s32(t.val[0])));
375       vst1_s16(dst->data(row + 1, col),
376                vget_high_s16(vreinterpretq_s16_s32(t.val[0])));
377       vst1_s16(dst->data(row + 2, col),
378                vget_low_s16(vreinterpretq_s16_s32(t.val[1])));
379       vst1_s16(dst->data(row + 3, col),
380                vget_high_s16(vreinterpretq_s16_s32(t.val[1])));
381     }
382   }
383 };
384 
385 template <typename DstType>
386 struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
387   static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
388                   int col) {
389     std::int32_t* dst_ptr = dst->data(row, col);
390     if (DstType::kOrder == MapOrder::ColMajor) {
391       int col_stride = dst->cols_stride();
392       for (int i = 0; i < 4; i++) {
393         vst1q_s32(dst_ptr + i * col_stride + 0, src.buf.reg[2 * i + 0]);
394         vst1q_s32(dst_ptr + i * col_stride + 4, src.buf.reg[2 * i + 1]);
395       }
396     } else {
397       int row_stride = dst->rows_stride();
398       RegBlockInt32<4, 4> top;
399       top.buf.reg[0] = src.buf.reg[0];
400       top.buf.reg[1] = src.buf.reg[2];
401       top.buf.reg[2] = src.buf.reg[4];
402       top.buf.reg[3] = src.buf.reg[6];
403       const auto transpose_top = Transpose(top);
404       for (int i = 0; i < 4; i++) {
405         vst1q_s32(dst_ptr + i * row_stride, transpose_top.buf.reg[i]);
406       }
407       RegBlockInt32<4, 4> bottom;
408       bottom.buf.reg[0] = src.buf.reg[1];
409       bottom.buf.reg[1] = src.buf.reg[3];
410       bottom.buf.reg[2] = src.buf.reg[5];
411       bottom.buf.reg[3] = src.buf.reg[7];
412       const auto transpose_bottom = Transpose(bottom);
413       for (int i = 0; i < 4; i++) {
414         vst1q_s32(dst_ptr + (i + 4) * row_stride, transpose_bottom.buf.reg[i]);
415       }
416     }
417   }
418 };
419 
420 template <typename DstType>
421 struct StoreFinalOutputImpl<RegBlockInt16<8, 4>, DstType> {
422   static void Run(const RegBlockInt16<8, 4>& src, DstType* dst, int row,
423                   int col) {
424     if (DstType::kOrder == MapOrder::ColMajor) {
425       vst1q_s16(dst->data(row, col + 0), src.buf.reg[0]);
426       vst1q_s16(dst->data(row, col + 1), src.buf.reg[1]);
427       vst1q_s16(dst->data(row, col + 2), src.buf.reg[2]);
428       vst1q_s16(dst->data(row, col + 3), src.buf.reg[3]);
429     } else {
430       const int16x8x2_t t0 = vtrnq_s16(src.buf.reg[0], src.buf.reg[1]);
431       const int16x8x2_t t1 = vtrnq_s16(src.buf.reg[2], src.buf.reg[3]);
432       const int32x4x2_t u0 = vtrnq_s32(vreinterpretq_s32_s16(t0.val[0]),
433                                        vreinterpretq_s32_s16(t1.val[0]));
434       const int32x4x2_t u1 = vtrnq_s32(vreinterpretq_s32_s16(t0.val[1]),
435                                        vreinterpretq_s32_s16(t1.val[1]));
436       vst1_s16(dst->data(row + 0, col),
437                vget_low_s16(vreinterpretq_s16_s32(u0.val[0])));
438       vst1_s16(dst->data(row + 1, col),
439                vget_low_s16(vreinterpretq_s16_s32(u1.val[0])));
440       vst1_s16(dst->data(row + 2, col),
441                vget_low_s16(vreinterpretq_s16_s32(u0.val[1])));
442       vst1_s16(dst->data(row + 3, col),
443                vget_low_s16(vreinterpretq_s16_s32(u1.val[1])));
444       vst1_s16(dst->data(row + 4, col),
445                vget_high_s16(vreinterpretq_s16_s32(u0.val[0])));
446       vst1_s16(dst->data(row + 5, col),
447                vget_high_s16(vreinterpretq_s16_s32(u1.val[0])));
448       vst1_s16(dst->data(row + 6, col),
449                vget_high_s16(vreinterpretq_s16_s32(u0.val[1])));
450       vst1_s16(dst->data(row + 7, col),
451                vget_high_s16(vreinterpretq_s16_s32(u1.val[1])));
452     }
453   }
454 };
455 
456 template <typename DstType>
457 struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
458   static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
459                   int col) {
460     std::int32_t* dst_ptr = dst->data(row, col);
461     if (DstType::kOrder == MapOrder::ColMajor) {
462       int col_stride = dst->cols_stride();
463       for (int i = 0; i < 8; i++) {
464         vst1q_s32(dst_ptr + i * col_stride, src.buf.reg[2 * i]);
465         vst1q_s32(dst_ptr + i * col_stride + 4, src.buf.reg[2 * i + 1]);
466       }
467     } else {
468       int row_stride = dst->rows_stride();
469       RegBlockInt32<4, 4> top_left;
470       top_left.buf.reg[0] = src.buf.reg[0];
471       top_left.buf.reg[1] = src.buf.reg[2];
472       top_left.buf.reg[2] = src.buf.reg[4];
473       top_left.buf.reg[3] = src.buf.reg[6];
474       const auto transpose_top_left = Transpose(top_left);
475       for (int i = 0; i < 4; i++) {
476         vst1q_s32(dst_ptr + i * row_stride, transpose_top_left.buf.reg[i]);
477       }
478       RegBlockInt32<4, 4> bottom_left;
479       bottom_left.buf.reg[0] = src.buf.reg[1];
480       bottom_left.buf.reg[1] = src.buf.reg[3];
481       bottom_left.buf.reg[2] = src.buf.reg[5];
482       bottom_left.buf.reg[3] = src.buf.reg[7];
483       const auto transpose_bottom_left = Transpose(bottom_left);
484       for (int i = 0; i < 4; i++) {
485         vst1q_s32(dst_ptr + (i + 4) * row_stride,
486                   transpose_bottom_left.buf.reg[i]);
487       }
488       RegBlockInt32<4, 4> top_right;
489       top_right.buf.reg[0] = src.buf.reg[8];
490       top_right.buf.reg[1] = src.buf.reg[10];
491       top_right.buf.reg[2] = src.buf.reg[12];
492       top_right.buf.reg[3] = src.buf.reg[14];
493       const auto transpose_top_right = Transpose(top_right);
494       for (int i = 0; i < 4; i++) {
495         vst1q_s32(dst_ptr + i * row_stride + 4, transpose_top_right.buf.reg[i]);
496       }
497       RegBlockInt32<4, 4> bottom_right;
498       bottom_right.buf.reg[0] = src.buf.reg[9];
499       bottom_right.buf.reg[1] = src.buf.reg[11];
500       bottom_right.buf.reg[2] = src.buf.reg[13];
501       bottom_right.buf.reg[3] = src.buf.reg[15];
502       const auto transpose_bottom_right = Transpose(bottom_right);
503       for (int i = 0; i < 4; i++) {
504         vst1q_s32(dst_ptr + (i + 4) * row_stride + 4,
505                   transpose_bottom_right.buf.reg[i]);
506       }
507     }
508   }
509 };
510 
511 template <typename DstType>
512 struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
513   static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
514                   int col) {
515     std::int32_t* dst_ptr = dst->data(row, col);
516     if (DstType::kOrder == MapOrder::ColMajor) {
517       vst1q_s32(dst_ptr, src.buf.reg[0]);
518     } else {
519       int row_stride = dst->rows_stride();
520       vst1q_lane_s32(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
521       vst1q_lane_s32(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
522       vst1q_lane_s32(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
523       vst1q_lane_s32(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
524     }
525   }
526 };
527 
528 template <typename DstType>
529 struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
530   static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
531                   int col) {
532     std::int32_t* dst_ptr = dst->data(row, col);
533     if (DstType::kOrder == MapOrder::RowMajor) {
534       vst1q_s32(dst_ptr, src.buf.reg[0]);
535     } else {
536       int col_stride = dst->cols_stride();
537       vst1q_lane_s32(dst_ptr + 0 * col_stride, src.buf.reg[0], 0);
538       vst1q_lane_s32(dst_ptr + 1 * col_stride, src.buf.reg[0], 1);
539       vst1q_lane_s32(dst_ptr + 2 * col_stride, src.buf.reg[0], 2);
540       vst1q_lane_s32(dst_ptr + 3 * col_stride, src.buf.reg[0], 3);
541     }
542   }
543 };
544 
545 template <typename DstType>
546 struct StoreFinalOutputImpl<RegBlockInt16<1, 4>, DstType> {
547   static void Run(const RegBlockInt16<1, 4>& src, DstType* dst, int row,
548                   int col) {
549     std::int16_t* dst_ptr = dst->data(row, col);
550     if (DstType::kOrder == MapOrder::RowMajor) {
551       vst1_s16(dst_ptr, src.buf.reg[0]);
552     } else {
553       int col_stride = dst->cols_stride();
554       vst1_lane_s16(dst_ptr + 0 * col_stride, src.buf.reg[0], 0);
555       vst1_lane_s16(dst_ptr + 1 * col_stride, src.buf.reg[0], 1);
556       vst1_lane_s16(dst_ptr + 2 * col_stride, src.buf.reg[0], 2);
557       vst1_lane_s16(dst_ptr + 3 * col_stride, src.buf.reg[0], 3);
558     }
559   }
560 };
561 
562 template <typename DstType>
563 struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
564   static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
565                   int col) {
566     const std::uint32_t src_reg = src.buf.reg[0];
567     for (int i = 0; i < 4; i++) {
568       *dst->data(row + i, col) = (src_reg >> (8 * i));
569     }
570   }
571 };
572 
573 template <typename DstType>
574 struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
575   static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
576                   int col) {
577     for (int i = 0; i < 4; i++) {
578       *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
579     }
580   }
581 };
582 
583 template <typename DstType>
584 struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
585   static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
586                   int col) {
587     std::uint8_t* dst_ptr = dst->data(row, col);
588     if (DstType::kOrder == MapOrder::ColMajor) {
589       vst1_u8(dst_ptr, src.buf.reg[0]);
590     } else {
591       const int row_stride = dst->rows_stride();
592       vst1_lane_u8(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
593       vst1_lane_u8(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
594       vst1_lane_u8(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
595       vst1_lane_u8(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
596       vst1_lane_u8(dst_ptr + 4 * row_stride, src.buf.reg[0], 4);
597       vst1_lane_u8(dst_ptr + 5 * row_stride, src.buf.reg[0], 5);
598       vst1_lane_u8(dst_ptr + 6 * row_stride, src.buf.reg[0], 6);
599       vst1_lane_u8(dst_ptr + 7 * row_stride, src.buf.reg[0], 7);
600     }
601   }
602 };
603 
604 template <typename DstType>
605 struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
606   static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
607                   int col) {
608     std::uint8_t* dst_ptr = dst->data(row, col);
609     const int row_stride = dst->rows_stride();
610     const int col_stride = dst->cols_stride();
611     for (int i = 0; i < 2; i++) {
612       vst1_lane_u8(dst_ptr + 0 * row_stride + (2 * i + 0) * col_stride,
613                    src.buf.reg[i], 0);
614       vst1_lane_u8(dst_ptr + 1 * row_stride + (2 * i + 0) * col_stride,
615                    src.buf.reg[i], 1);
616       vst1_lane_u8(dst_ptr + 2 * row_stride + (2 * i + 0) * col_stride,
617                    src.buf.reg[i], 2);
618       vst1_lane_u8(dst_ptr + 3 * row_stride + (2 * i + 0) * col_stride,
619                    src.buf.reg[i], 3);
620       vst1_lane_u8(dst_ptr + 0 * row_stride + (2 * i + 1) * col_stride,
621                    src.buf.reg[i], 4);
622       vst1_lane_u8(dst_ptr + 1 * row_stride + (2 * i + 1) * col_stride,
623                    src.buf.reg[i], 5);
624       vst1_lane_u8(dst_ptr + 2 * row_stride + (2 * i + 1) * col_stride,
625                    src.buf.reg[i], 6);
626       vst1_lane_u8(dst_ptr + 3 * row_stride + (2 * i + 1) * col_stride,
627                    src.buf.reg[i], 7);
628     }
629   }
630 };
631 
632 template <typename DstType>
633 struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
634   static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
635                   int col) {
636     std::uint8_t* dst_ptr = dst->data(row, col);
637     if (DstType::kOrder == MapOrder::ColMajor) {
638       int col_stride = dst->cols_stride();
639       for (int i = 0; i < 4; i++) {
640         vst1_u8(dst_ptr + i * col_stride, src.buf.reg[i]);
641       }
642     } else {
643       int row_stride = dst->rows_stride();
644       for (int i = 0; i < 4; i++) {
645         std::uint8_t* col_ptr = dst_ptr + i;
646         vst1_lane_u8(col_ptr + 0 * row_stride, src.buf.reg[i], 0);
647         vst1_lane_u8(col_ptr + 1 * row_stride, src.buf.reg[i], 1);
648         vst1_lane_u8(col_ptr + 2 * row_stride, src.buf.reg[i], 2);
649         vst1_lane_u8(col_ptr + 3 * row_stride, src.buf.reg[i], 3);
650         vst1_lane_u8(col_ptr + 4 * row_stride, src.buf.reg[i], 4);
651         vst1_lane_u8(col_ptr + 5 * row_stride, src.buf.reg[i], 5);
652         vst1_lane_u8(col_ptr + 6 * row_stride, src.buf.reg[i], 6);
653         vst1_lane_u8(col_ptr + 7 * row_stride, src.buf.reg[i], 7);
654       }
655     }
656   }
657 };
658 
659 inline RegBlockUint8<8, 8> Transpose(const RegBlockUint8<8, 8>& src) {
660   uint8x8x2_t a[4];
661   a[0] = vtrn_u8(src.buf.reg[0], src.buf.reg[1]);
662   a[1] = vtrn_u8(src.buf.reg[2], src.buf.reg[3]);
663   a[2] = vtrn_u8(src.buf.reg[4], src.buf.reg[5]);
664   a[3] = vtrn_u8(src.buf.reg[6], src.buf.reg[7]);
665   uint16x4x2_t b[4];
666   b[0] = vtrn_u16(vreinterpret_u16_u8(a[0].val[0]),
667                   vreinterpret_u16_u8(a[1].val[0]));
668   b[1] = vtrn_u16(vreinterpret_u16_u8(a[0].val[1]),
669                   vreinterpret_u16_u8(a[1].val[1]));
670   b[2] = vtrn_u16(vreinterpret_u16_u8(a[2].val[0]),
671                   vreinterpret_u16_u8(a[3].val[0]));
672   b[3] = vtrn_u16(vreinterpret_u16_u8(a[2].val[1]),
673                   vreinterpret_u16_u8(a[3].val[1]));
674   uint32x2x2_t c[4];
675   c[0] = vtrn_u32(vreinterpret_u32_u16(b[0].val[0]),
676                   vreinterpret_u32_u16(b[2].val[0]));
677   c[1] = vtrn_u32(vreinterpret_u32_u16(b[1].val[0]),
678                   vreinterpret_u32_u16(b[3].val[0]));
679   c[2] = vtrn_u32(vreinterpret_u32_u16(b[0].val[1]),
680                   vreinterpret_u32_u16(b[2].val[1]));
681   c[3] = vtrn_u32(vreinterpret_u32_u16(b[1].val[1]),
682                   vreinterpret_u32_u16(b[3].val[1]));
683   RegBlockUint8<8, 8> result;
684   result.buf.reg[0] = vreinterpret_u8_u32(c[0].val[0]);
685   result.buf.reg[1] = vreinterpret_u8_u32(c[1].val[0]);
686   result.buf.reg[2] = vreinterpret_u8_u32(c[2].val[0]);
687   result.buf.reg[3] = vreinterpret_u8_u32(c[3].val[0]);
688   result.buf.reg[4] = vreinterpret_u8_u32(c[0].val[1]);
689   result.buf.reg[5] = vreinterpret_u8_u32(c[1].val[1]);
690   result.buf.reg[6] = vreinterpret_u8_u32(c[2].val[1]);
691   result.buf.reg[7] = vreinterpret_u8_u32(c[3].val[1]);
692   return result;
693 }
694 
695 template <typename DstType>
696 struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
697   static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
698                   int col) {
699     const auto& block =
700         DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
701     std::uint8_t* dst_ptr = dst->data(row, col);
702     int stride = dst->stride();
703     for (int i = 0; i < 8; i++) {
704       vst1_u8(dst_ptr + i * stride, block.buf.reg[i]);
705     }
706   }
707 };
708 
709 template <typename DstType>
710 struct StoreFinalOutputImpl<RegBlockInt8<4, 1>, DstType> {
711   static void Run(const RegBlockInt8<4, 1>& src, DstType* dst, int row,
712                   int col) {
713     const std::int32_t src_reg = src.buf.reg[0];
714     for (int i = 0; i < 4; i++) {
715       *dst->data(row + i, col) = (src_reg >> (8 * i));
716     }
717   }
718 };
719 
720 template <typename DstType>
721 struct StoreFinalOutputImpl<RegBlockInt8<1, 4>, DstType> {
722   static void Run(const RegBlockInt8<1, 4>& src, DstType* dst, int row,
723                   int col) {
724     for (int i = 0; i < 4; i++) {
725       *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
726     }
727   }
728 };
729 
730 template <typename DstType>
731 struct StoreFinalOutputImpl<RegBlockInt8<8, 1>, DstType> {
732   static void Run(const RegBlockInt8<8, 1>& src, DstType* dst, int row,
733                   int col) {
734     std::int8_t* dst_ptr = dst->data(row, col);
735     if (DstType::kOrder == MapOrder::ColMajor) {
736       vst1_s8(dst_ptr, src.buf.reg[0]);
737     } else {
738       const int row_stride = dst->rows_stride();
739       vst1_lane_s8(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
740       vst1_lane_s8(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
741       vst1_lane_s8(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
742       vst1_lane_s8(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
743       vst1_lane_s8(dst_ptr + 4 * row_stride, src.buf.reg[0], 4);
744       vst1_lane_s8(dst_ptr + 5 * row_stride, src.buf.reg[0], 5);
745       vst1_lane_s8(dst_ptr + 6 * row_stride, src.buf.reg[0], 6);
746       vst1_lane_s8(dst_ptr + 7 * row_stride, src.buf.reg[0], 7);
747     }
748   }
749 };
750 
751 template <typename DstType>
752 struct StoreFinalOutputImpl<RegBlockInt8<4, 4>, DstType> {
753   static void Run(const RegBlockInt8<4, 4>& src, DstType* dst, int row,
754                   int col) {
755     std::int8_t* dst_ptr = dst->data(row, col);
756     const int row_stride = dst->rows_stride();
757     const int col_stride = dst->cols_stride();
758     for (int i = 0; i < 2; i++) {
759       vst1_lane_s8(dst_ptr + 0 * row_stride + (2 * i + 0) * col_stride,
760                    src.buf.reg[i], 0);
761       vst1_lane_s8(dst_ptr + 1 * row_stride + (2 * i + 0) * col_stride,
762                    src.buf.reg[i], 1);
763       vst1_lane_s8(dst_ptr + 2 * row_stride + (2 * i + 0) * col_stride,
764                    src.buf.reg[i], 2);
765       vst1_lane_s8(dst_ptr + 3 * row_stride + (2 * i + 0) * col_stride,
766                    src.buf.reg[i], 3);
767       vst1_lane_s8(dst_ptr + 0 * row_stride + (2 * i + 1) * col_stride,
768                    src.buf.reg[i], 4);
769       vst1_lane_s8(dst_ptr + 1 * row_stride + (2 * i + 1) * col_stride,
770                    src.buf.reg[i], 5);
771       vst1_lane_s8(dst_ptr + 2 * row_stride + (2 * i + 1) * col_stride,
772                    src.buf.reg[i], 6);
773       vst1_lane_s8(dst_ptr + 3 * row_stride + (2 * i + 1) * col_stride,
774                    src.buf.reg[i], 7);
775     }
776   }
777 };
778 
779 template <typename DstType>
780 struct StoreFinalOutputImpl<RegBlockInt8<8, 4>, DstType> {
781   static void Run(const RegBlockInt8<8, 4>& src, DstType* dst, int row,
782                   int col) {
783     std::int8_t* dst_ptr = dst->data(row, col);
784     if (DstType::kOrder == MapOrder::ColMajor) {
785       int col_stride = dst->cols_stride();
786       for (int i = 0; i < 4; i++) {
787         vst1_s8(dst_ptr + i * col_stride, src.buf.reg[i]);
788       }
789     } else {
790       int row_stride = dst->rows_stride();
791       for (int i = 0; i < 4; i++) {
792         std::int8_t* col_ptr = dst_ptr + i;
793         vst1_lane_s8(col_ptr + 0 * row_stride, src.buf.reg[i], 0);
794         vst1_lane_s8(col_ptr + 1 * row_stride, src.buf.reg[i], 1);
795         vst1_lane_s8(col_ptr + 2 * row_stride, src.buf.reg[i], 2);
796         vst1_lane_s8(col_ptr + 3 * row_stride, src.buf.reg[i], 3);
797         vst1_lane_s8(col_ptr + 4 * row_stride, src.buf.reg[i], 4);
798         vst1_lane_s8(col_ptr + 5 * row_stride, src.buf.reg[i], 5);
799         vst1_lane_s8(col_ptr + 6 * row_stride, src.buf.reg[i], 6);
800         vst1_lane_s8(col_ptr + 7 * row_stride, src.buf.reg[i], 7);
801       }
802     }
803   }
804 };
805 
806 inline RegBlockInt8<8, 8> Transpose(const RegBlockInt8<8, 8>& src) {
807   int8x8x2_t a[4];
808   a[0] = vtrn_s8(src.buf.reg[0], src.buf.reg[1]);
809   a[1] = vtrn_s8(src.buf.reg[2], src.buf.reg[3]);
810   a[2] = vtrn_s8(src.buf.reg[4], src.buf.reg[5]);
811   a[3] = vtrn_s8(src.buf.reg[6], src.buf.reg[7]);
812   int16x4x2_t b[4];
813   b[0] = vtrn_s16(vreinterpret_s16_s8(a[0].val[0]),
814                   vreinterpret_s16_s8(a[1].val[0]));
815   b[1] = vtrn_s16(vreinterpret_s16_s8(a[0].val[1]),
816                   vreinterpret_s16_s8(a[1].val[1]));
817   b[2] = vtrn_s16(vreinterpret_s16_s8(a[2].val[0]),
818                   vreinterpret_s16_s8(a[3].val[0]));
819   b[3] = vtrn_s16(vreinterpret_s16_s8(a[2].val[1]),
820                   vreinterpret_s16_s8(a[3].val[1]));
821   int32x2x2_t c[4];
822   c[0] = vtrn_s32(vreinterpret_s32_s16(b[0].val[0]),
823                   vreinterpret_s32_s16(b[2].val[0]));
824   c[1] = vtrn_s32(vreinterpret_s32_s16(b[1].val[0]),
825                   vreinterpret_s32_s16(b[3].val[0]));
826   c[2] = vtrn_s32(vreinterpret_s32_s16(b[0].val[1]),
827                   vreinterpret_s32_s16(b[2].val[1]));
828   c[3] = vtrn_s32(vreinterpret_s32_s16(b[1].val[1]),
829                   vreinterpret_s32_s16(b[3].val[1]));
830   RegBlockInt8<8, 8> result;
831   result.buf.reg[0] = vreinterpret_s8_s32(c[0].val[0]);
832   result.buf.reg[1] = vreinterpret_s8_s32(c[1].val[0]);
833   result.buf.reg[2] = vreinterpret_s8_s32(c[2].val[0]);
834   result.buf.reg[3] = vreinterpret_s8_s32(c[3].val[0]);
835   result.buf.reg[4] = vreinterpret_s8_s32(c[0].val[1]);
836   result.buf.reg[5] = vreinterpret_s8_s32(c[1].val[1]);
837   result.buf.reg[6] = vreinterpret_s8_s32(c[2].val[1]);
838   result.buf.reg[7] = vreinterpret_s8_s32(c[3].val[1]);
839   return result;
840 }
841 
842 template <typename DstType>
843 struct StoreFinalOutputImpl<RegBlockInt8<8, 8>, DstType> {
844   static void Run(const RegBlockInt8<8, 8>& src, DstType* dst, int row,
845                   int col) {
846     const auto& block =
847         DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
848     std::int8_t* dst_ptr = dst->data(row, col);
849     int stride = dst->stride();
850     for (int i = 0; i < 8; i++) {
851       vst1_s8(dst_ptr + i * stride, block.buf.reg[i]);
852     }
853   }
854 };
855 
856 template <typename DstType>
857 struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
858   static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
859                   int col) {
860     if (DstType::kOrder == MapOrder::ColMajor) {
861       vst1q_s16(dst->data(row, col + 0), src.buf.reg[0]);
862       vst1q_s16(dst->data(row, col + 1), src.buf.reg[1]);
863       vst1q_s16(dst->data(row, col + 2), src.buf.reg[2]);
864       vst1q_s16(dst->data(row, col + 3), src.buf.reg[3]);
865       vst1q_s16(dst->data(row, col + 4), src.buf.reg[4]);
866       vst1q_s16(dst->data(row, col + 5), src.buf.reg[5]);
867       vst1q_s16(dst->data(row, col + 6), src.buf.reg[6]);
868       vst1q_s16(dst->data(row, col + 7), src.buf.reg[7]);
869     } else {
870       int16x8x2_t a[4];
871       a[0] = vtrnq_s16(src.buf.reg[0], src.buf.reg[1]);
872       a[1] = vtrnq_s16(src.buf.reg[2], src.buf.reg[3]);
873       a[2] = vtrnq_s16(src.buf.reg[4], src.buf.reg[5]);
874       a[3] = vtrnq_s16(src.buf.reg[6], src.buf.reg[7]);
875       int32x4x2_t b[4];
876       b[0] = vtrnq_s32(vreinterpretq_s32_s16(a[0].val[0]),
877                        vreinterpretq_s32_s16(a[1].val[0]));
878       b[1] = vtrnq_s32(vreinterpretq_s32_s16(a[0].val[1]),
879                        vreinterpretq_s32_s16(a[1].val[1]));
880       b[2] = vtrnq_s32(vreinterpretq_s32_s16(a[2].val[0]),
881                        vreinterpretq_s32_s16(a[3].val[0]));
882       b[3] = vtrnq_s32(vreinterpretq_s32_s16(a[2].val[1]),
883                        vreinterpretq_s32_s16(a[3].val[1]));
884       vst1_s16(dst->data(row + 0, col + 0),
885                vget_low_s16(vreinterpretq_s16_s32(b[0].val[0])));
886       vst1_s16(dst->data(row + 0, col + 4),
887                vget_low_s16(vreinterpretq_s16_s32(b[2].val[0])));
888       vst1_s16(dst->data(row + 1, col + 0),
889                vget_low_s16(vreinterpretq_s16_s32(b[1].val[0])));
890       vst1_s16(dst->data(row + 1, col + 4),
891                vget_low_s16(vreinterpretq_s16_s32(b[3].val[0])));
892       vst1_s16(dst->data(row + 2, col + 0),
893                vget_low_s16(vreinterpretq_s16_s32(b[0].val[1])));
894       vst1_s16(dst->data(row + 2, col + 4),
895                vget_low_s16(vreinterpretq_s16_s32(b[2].val[1])));
896       vst1_s16(dst->data(row + 3, col + 0),
897                vget_low_s16(vreinterpretq_s16_s32(b[1].val[1])));
898       vst1_s16(dst->data(row + 3, col + 4),
899                vget_low_s16(vreinterpretq_s16_s32(b[3].val[1])));
900       vst1_s16(dst->data(row + 4, col + 0),
901                vget_high_s16(vreinterpretq_s16_s32(b[0].val[0])));
902       vst1_s16(dst->data(row + 4, col + 4),
903                vget_high_s16(vreinterpretq_s16_s32(b[2].val[0])));
904       vst1_s16(dst->data(row + 5, col + 0),
905                vget_high_s16(vreinterpretq_s16_s32(b[1].val[0])));
906       vst1_s16(dst->data(row + 5, col + 4),
907                vget_high_s16(vreinterpretq_s16_s32(b[3].val[0])));
908       vst1_s16(dst->data(row + 6, col + 0),
909                vget_high_s16(vreinterpretq_s16_s32(b[0].val[1])));
910       vst1_s16(dst->data(row + 6, col + 4),
911                vget_high_s16(vreinterpretq_s16_s32(b[2].val[1])));
912       vst1_s16(dst->data(row + 7, col + 0),
913                vget_high_s16(vreinterpretq_s16_s32(b[1].val[1])));
914       vst1_s16(dst->data(row + 7, col + 4),
915                vget_high_s16(vreinterpretq_s16_s32(b[3].val[1])));
916     }
917   }
918 };
919 
920 }  // namespace gemmlowp
921 
922 #endif  // GEMMLOWP_INTERNAL_OUTPUT_NEON_H_
923