// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// output_neon.h: optimized NEON specializations of the templates in output.h.

#ifndef GEMMLOWP_INTERNAL_OUTPUT_NEON_H_
#define GEMMLOWP_INTERNAL_OUTPUT_NEON_H_

#include "output.h"

#include <arm_neon.h>

namespace gemmlowp {

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<4>> {
  typedef RegBufferInt32<4> InputType;
  typedef RegBufferUint8<4> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x4_t res_16 = vqmovn_s32(input.reg[0]);
    uint8x8_t res_8 = vqmovun_s16(vcombine_s16(res_16, res_16));
    output.reg[0] = vget_lane_u32(vreinterpret_u32_u8(res_8), 0);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<8>> {
  typedef RegBufferInt32<8> InputType;
  typedef RegBufferUint8<8> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x8_t res_16 =
        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
    output.reg[0] = vqmovun_s16(res_16);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<16>> {
  typedef RegBufferInt32<16> InputType;
  typedef RegBufferUint8<16> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x8_t res_16_0 =
        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
    int16x8_t res_16_1 =
        vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
    output.reg[0] = vqmovun_s16(res_16_0);
    output.reg[1] = vqmovun_s16(res_16_1);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<32>> {
  typedef RegBufferInt32<32> InputType;
  typedef RegBufferUint8<32> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x8_t res_16[4];
    for (int i = 0; i < 4; i++) {
      res_16[i] = vcombine_s16(vqmovn_s32(input.reg[2 * i]),
                               vqmovn_s32(input.reg[2 * i + 1]));
    }
    for (int i = 0; i < 4; i++) {
      output.reg[i] = vqmovun_s16(res_16[i]);
    }
    return output;
  }
};
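
// Saturating cast to std::int8_t: same structure as the uint8 specializations
// above, except that the final narrowing step uses vqmovn_s16 (signed) instead
// of vqmovun_s16.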
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
                                 RegBufferInt32<4>> {
  typedef RegBufferInt32<4> InputType;
  typedef RegBufferInt8<4> OutputType;

  typedef OutputStageSaturatingCastToInt8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x4_t res_16 = vqmovn_s32(input.reg[0]);
    int8x8_t res_8 = vqmovn_s16(vcombine_s16(res_16, res_16));
    output.reg[0] = vget_lane_s32(vreinterpret_s32_s8(res_8), 0);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
                                 RegBufferInt32<8>> {
  typedef RegBufferInt32<8> InputType;
  typedef RegBufferInt8<8> OutputType;

  typedef OutputStageSaturatingCastToInt8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x8_t res_16 =
        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
    output.reg[0] = vqmovn_s16(res_16);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
                                 RegBufferInt32<16>> {
  typedef RegBufferInt32<16> InputType;
  typedef RegBufferInt8<16> OutputType;

  typedef OutputStageSaturatingCastToInt8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x8_t res_16_0 =
        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
    int16x8_t res_16_1 =
        vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
    output.reg[0] = vqmovn_s16(res_16_0);
    output.reg[1] = vqmovn_s16(res_16_1);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
                                 RegBufferInt32<32>> {
  typedef RegBufferInt32<32> InputType;
  typedef RegBufferInt8<32> OutputType;

  typedef OutputStageSaturatingCastToInt8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x8_t res_16[4];
    for (int i = 0; i < 4; i++) {
      res_16[i] = vcombine_s16(vqmovn_s32(input.reg[2 * i]),
                               vqmovn_s32(input.reg[2 * i + 1]));
    }
    for (int i = 0; i < 4; i++) {
      output.reg[i] = vqmovn_s16(res_16[i]);
    }
    return output;
  }
};
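
// Saturating cast to std::int16_t: each int32x4 input register narrows to one
// int16x4 with a single vqmovn_s32; no further narrowing is needed.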
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<4>> {
  typedef RegBufferInt32<4> InputType;
  typedef RegBufferInt16<4> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    output.reg[0] = vqmovn_s32(input.reg[0]);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<8>> {
  typedef RegBufferInt32<8> InputType;
  typedef RegBufferInt16<8> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    output.reg[0] =
        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<16>> {
  typedef RegBufferInt32<16> InputType;
  typedef RegBufferInt16<16> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    output.reg[0] =
        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
    output.reg[1] =
        vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<32>> {
  typedef RegBufferInt32<32> InputType;
  typedef RegBufferInt16<32> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    output.reg[0] =
        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
    output.reg[1] =
        vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
    output.reg[2] =
        vcombine_s16(vqmovn_s32(input.reg[4]), vqmovn_s32(input.reg[5]));
    output.reg[3] =
        vcombine_s16(vqmovn_s32(input.reg[6]), vqmovn_s32(input.reg[7]));
    return output;
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
  static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
      StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
    } else {
      vst1q_lane_s32(dst->data(row + 0, col), src.buf.reg[0], 0);
      vst1q_lane_s32(dst->data(row + 1, col), src.buf.reg[0], 1);
      vst1q_lane_s32(dst->data(row + 2, col), src.buf.reg[0], 2);
      vst1q_lane_s32(dst->data(row + 3, col), src.buf.reg[0], 3);
      vst1q_lane_s32(dst->data(row + 4, col), src.buf.reg[1], 0);
      vst1q_lane_s32(dst->data(row + 5, col), src.buf.reg[1], 1);
      vst1q_lane_s32(dst->data(row + 6, col), src.buf.reg[1], 2);
      vst1q_lane_s32(dst->data(row + 7, col), src.buf.reg[1], 3);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<4, 1>, DstType> {
  static void Run(const RegBlockInt16<4, 1>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      StoreInt16x4(dst->data(row, col), src.buf.reg[0]);
    } else {
      vst1_lane_s16(dst->data(row + 0, col), src.buf.reg[0], 0);
      vst1_lane_s16(dst->data(row + 1, col), src.buf.reg[0], 1);
      vst1_lane_s16(dst->data(row + 2, col), src.buf.reg[0], 2);
      vst1_lane_s16(dst->data(row + 3, col), src.buf.reg[0], 3);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<8, 1>, DstType> {
  static void Run(const RegBlockInt16<8, 1>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      StoreInt16x8(dst->data(row, col), src.buf.reg[0]);
    } else {
      vst1q_lane_s16(dst->data(row + 0, col), src.buf.reg[0], 0);
      vst1q_lane_s16(dst->data(row + 1, col), src.buf.reg[0], 1);
      vst1q_lane_s16(dst->data(row + 2, col), src.buf.reg[0], 2);
      vst1q_lane_s16(dst->data(row + 3, col), src.buf.reg[0], 3);
      vst1q_lane_s16(dst->data(row + 4, col), src.buf.reg[0], 4);
      vst1q_lane_s16(dst->data(row + 5, col), src.buf.reg[0], 5);
      vst1q_lane_s16(dst->data(row + 6, col), src.buf.reg[0], 6);
      vst1q_lane_s16(dst->data(row + 7, col), src.buf.reg[0], 7);
    }
  }
};
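
// Transposes a 4x4 block of int32 lanes: vtrnq_s32 interleaves each pair of
// rows, then recombining low/high halves yields the transposed columns.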
inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
  const int32x4x2_t t0 = vtrnq_s32(src.buf.reg[0], src.buf.reg[1]);
  const int32x4x2_t t1 = vtrnq_s32(src.buf.reg[2], src.buf.reg[3]);
  RegBlockInt32<4, 4> result;
  result.buf.reg[0] =
      vcombine_s32(vget_low_s32(t0.val[0]), vget_low_s32(t1.val[0]));
  result.buf.reg[1] =
      vcombine_s32(vget_low_s32(t0.val[1]), vget_low_s32(t1.val[1]));
  result.buf.reg[2] =
      vcombine_s32(vget_high_s32(t0.val[0]), vget_high_s32(t1.val[0]));
  result.buf.reg[3] =
      vcombine_s32(vget_high_s32(t0.val[1]), vget_high_s32(t1.val[1]));
  return result;
}

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
  static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
                  int col) {
    const auto& block =
        DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
    std::int32_t* dst_ptr = dst->data(row, col);
    int stride = dst->stride();
    for (int i = 0; i < 4; i++) {
      vst1q_s32(dst_ptr + i * stride, block.buf.reg[i]);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<4, 4>, DstType> {
  static void Run(const RegBlockInt16<4, 4>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      vst1_s16(dst->data(row, col + 0), vget_low_s16(src.buf.reg[0]));
      vst1_s16(dst->data(row, col + 1), vget_high_s16(src.buf.reg[0]));
      vst1_s16(dst->data(row, col + 2), vget_low_s16(src.buf.reg[1]));
      vst1_s16(dst->data(row, col + 3), vget_high_s16(src.buf.reg[1]));
    } else {
      const int16x4x2_t t0 =
          vtrn_s16(vget_low_s16(src.buf.reg[0]), vget_high_s16(src.buf.reg[0]));
      const int16x4x2_t t1 =
          vtrn_s16(vget_low_s16(src.buf.reg[1]), vget_high_s16(src.buf.reg[1]));
      const int32x4x2_t t =
          vtrnq_s32(vreinterpretq_s32_s16(vcombine_s16(t0.val[0], t0.val[1])),
                    vreinterpretq_s32_s16(vcombine_s16(t1.val[0], t1.val[1])));
      vst1_s16(dst->data(row + 0, col),
               vget_low_s16(vreinterpretq_s16_s32(t.val[0])));
      vst1_s16(dst->data(row + 1, col),
               vget_high_s16(vreinterpretq_s16_s32(t.val[0])));
      vst1_s16(dst->data(row + 2, col),
               vget_low_s16(vreinterpretq_s16_s32(t.val[1])));
      vst1_s16(dst->data(row + 3, col),
               vget_high_s16(vreinterpretq_s16_s32(t.val[1])));
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
  static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
                  int col) {
    std::int32_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      int col_stride = dst->cols_stride();
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * col_stride + 0, src.buf.reg[2 * i + 0]);
        vst1q_s32(dst_ptr + i * col_stride + 4, src.buf.reg[2 * i + 1]);
      }
    } else {
      int row_stride = dst->rows_stride();
      RegBlockInt32<4, 4> top;
      top.buf.reg[0] = src.buf.reg[0];
      top.buf.reg[1] = src.buf.reg[2];
      top.buf.reg[2] = src.buf.reg[4];
      top.buf.reg[3] = src.buf.reg[6];
      const auto transpose_top = Transpose(top);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * row_stride, transpose_top.buf.reg[i]);
      }
      RegBlockInt32<4, 4> bottom;
      bottom.buf.reg[0] = src.buf.reg[1];
      bottom.buf.reg[1] = src.buf.reg[3];
      bottom.buf.reg[2] = src.buf.reg[5];
      bottom.buf.reg[3] = src.buf.reg[7];
      const auto transpose_bottom = Transpose(bottom);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + (i + 4) * row_stride, transpose_bottom.buf.reg[i]);
      }
    }
  }
};
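
// For row-major destinations, the 8x4 int16 block is transposed in registers
// (vtrnq_s16 then vtrnq_s32) so each destination row can be written with a
// single 4-lane store.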
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<8, 4>, DstType> {
  static void Run(const RegBlockInt16<8, 4>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      vst1q_s16(dst->data(row, col + 0), src.buf.reg[0]);
      vst1q_s16(dst->data(row, col + 1), src.buf.reg[1]);
      vst1q_s16(dst->data(row, col + 2), src.buf.reg[2]);
      vst1q_s16(dst->data(row, col + 3), src.buf.reg[3]);
    } else {
      const int16x8x2_t t0 = vtrnq_s16(src.buf.reg[0], src.buf.reg[1]);
      const int16x8x2_t t1 = vtrnq_s16(src.buf.reg[2], src.buf.reg[3]);
      const int32x4x2_t u0 = vtrnq_s32(vreinterpretq_s32_s16(t0.val[0]),
                                       vreinterpretq_s32_s16(t1.val[0]));
      const int32x4x2_t u1 = vtrnq_s32(vreinterpretq_s32_s16(t0.val[1]),
                                       vreinterpretq_s32_s16(t1.val[1]));
      vst1_s16(dst->data(row + 0, col),
               vget_low_s16(vreinterpretq_s16_s32(u0.val[0])));
      vst1_s16(dst->data(row + 1, col),
               vget_low_s16(vreinterpretq_s16_s32(u1.val[0])));
      vst1_s16(dst->data(row + 2, col),
               vget_low_s16(vreinterpretq_s16_s32(u0.val[1])));
      vst1_s16(dst->data(row + 3, col),
               vget_low_s16(vreinterpretq_s16_s32(u1.val[1])));
      vst1_s16(dst->data(row + 4, col),
               vget_high_s16(vreinterpretq_s16_s32(u0.val[0])));
      vst1_s16(dst->data(row + 5, col),
               vget_high_s16(vreinterpretq_s16_s32(u1.val[0])));
      vst1_s16(dst->data(row + 6, col),
               vget_high_s16(vreinterpretq_s16_s32(u0.val[1])));
      vst1_s16(dst->data(row + 7, col),
               vget_high_s16(vreinterpretq_s16_s32(u1.val[1])));
    }
  }
};
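
// For row-major destinations, the 8x8 int32 block is handled as four 4x4
// sub-blocks, each transposed with Transpose() before being stored row by row.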
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
  static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
                  int col) {
    std::int32_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      int col_stride = dst->cols_stride();
      for (int i = 0; i < 8; i++) {
        vst1q_s32(dst_ptr + i * col_stride, src.buf.reg[2 * i]);
        vst1q_s32(dst_ptr + i * col_stride + 4, src.buf.reg[2 * i + 1]);
      }
    } else {
      int row_stride = dst->rows_stride();
      RegBlockInt32<4, 4> top_left;
      top_left.buf.reg[0] = src.buf.reg[0];
      top_left.buf.reg[1] = src.buf.reg[2];
      top_left.buf.reg[2] = src.buf.reg[4];
      top_left.buf.reg[3] = src.buf.reg[6];
      const auto transpose_top_left = Transpose(top_left);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * row_stride, transpose_top_left.buf.reg[i]);
      }
      RegBlockInt32<4, 4> bottom_left;
      bottom_left.buf.reg[0] = src.buf.reg[1];
      bottom_left.buf.reg[1] = src.buf.reg[3];
      bottom_left.buf.reg[2] = src.buf.reg[5];
      bottom_left.buf.reg[3] = src.buf.reg[7];
      const auto transpose_bottom_left = Transpose(bottom_left);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + (i + 4) * row_stride,
                  transpose_bottom_left.buf.reg[i]);
      }
      RegBlockInt32<4, 4> top_right;
      top_right.buf.reg[0] = src.buf.reg[8];
      top_right.buf.reg[1] = src.buf.reg[10];
      top_right.buf.reg[2] = src.buf.reg[12];
      top_right.buf.reg[3] = src.buf.reg[14];
      const auto transpose_top_right = Transpose(top_right);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * row_stride + 4, transpose_top_right.buf.reg[i]);
      }
      RegBlockInt32<4, 4> bottom_right;
      bottom_right.buf.reg[0] = src.buf.reg[9];
      bottom_right.buf.reg[1] = src.buf.reg[11];
      bottom_right.buf.reg[2] = src.buf.reg[13];
      bottom_right.buf.reg[3] = src.buf.reg[15];
      const auto transpose_bottom_right = Transpose(bottom_right);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + (i + 4) * row_stride + 4,
                  transpose_bottom_right.buf.reg[i]);
      }
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
  static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
                  int col) {
    std::int32_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      vst1q_s32(dst_ptr, src.buf.reg[0]);
    } else {
      int row_stride = dst->rows_stride();
      vst1q_lane_s32(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
      vst1q_lane_s32(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
      vst1q_lane_s32(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
      vst1q_lane_s32(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
  static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
                  int col) {
    std::int32_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::RowMajor) {
      vst1q_s32(dst_ptr, src.buf.reg[0]);
    } else {
      int col_stride = dst->cols_stride();
      vst1q_lane_s32(dst_ptr + 0 * col_stride, src.buf.reg[0], 0);
      vst1q_lane_s32(dst_ptr + 1 * col_stride, src.buf.reg[0], 1);
      vst1q_lane_s32(dst_ptr + 2 * col_stride, src.buf.reg[0], 2);
      vst1q_lane_s32(dst_ptr + 3 * col_stride, src.buf.reg[0], 3);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<1, 4>, DstType> {
  static void Run(const RegBlockInt16<1, 4>& src, DstType* dst, int row,
                  int col) {
    std::int16_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::RowMajor) {
      vst1_s16(dst_ptr, src.buf.reg[0]);
    } else {
      int col_stride = dst->cols_stride();
      vst1_lane_s16(dst_ptr + 0 * col_stride, src.buf.reg[0], 0);
      vst1_lane_s16(dst_ptr + 1 * col_stride, src.buf.reg[0], 1);
      vst1_lane_s16(dst_ptr + 2 * col_stride, src.buf.reg[0], 2);
      vst1_lane_s16(dst_ptr + 3 * col_stride, src.buf.reg[0], 3);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
  static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
                  int col) {
    const std::uint32_t src_reg = src.buf.reg[0];
    for (int i = 0; i < 4; i++) {
      *dst->data(row + i, col) = (src_reg >> (8 * i));
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
  static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
                  int col) {
    for (int i = 0; i < 4; i++) {
      *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
  static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      vst1_u8(dst_ptr, src.buf.reg[0]);
    } else {
      const int row_stride = dst->rows_stride();
      vst1_lane_u8(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
      vst1_lane_u8(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
      vst1_lane_u8(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
      vst1_lane_u8(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
      vst1_lane_u8(dst_ptr + 4 * row_stride, src.buf.reg[0], 4);
      vst1_lane_u8(dst_ptr + 5 * row_stride, src.buf.reg[0], 5);
      vst1_lane_u8(dst_ptr + 6 * row_stride, src.buf.reg[0], 6);
      vst1_lane_u8(dst_ptr + 7 * row_stride, src.buf.reg[0], 7);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
  static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t* dst_ptr = dst->data(row, col);
    const int row_stride = dst->rows_stride();
    const int col_stride = dst->cols_stride();
    for (int i = 0; i < 2; i++) {
      vst1_lane_u8(dst_ptr + 0 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 0);
      vst1_lane_u8(dst_ptr + 1 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 1);
      vst1_lane_u8(dst_ptr + 2 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 2);
      vst1_lane_u8(dst_ptr + 3 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 3);
      vst1_lane_u8(dst_ptr + 0 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 4);
      vst1_lane_u8(dst_ptr + 1 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 5);
      vst1_lane_u8(dst_ptr + 2 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 6);
      vst1_lane_u8(dst_ptr + 3 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 7);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
  static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      int col_stride = dst->cols_stride();
      for (int i = 0; i < 4; i++) {
        vst1_u8(dst_ptr + i * col_stride, src.buf.reg[i]);
      }
    } else {
      int row_stride = dst->rows_stride();
      for (int i = 0; i < 4; i++) {
        std::uint8_t* col_ptr = dst_ptr + i;
        vst1_lane_u8(col_ptr + 0 * row_stride, src.buf.reg[i], 0);
        vst1_lane_u8(col_ptr + 1 * row_stride, src.buf.reg[i], 1);
        vst1_lane_u8(col_ptr + 2 * row_stride, src.buf.reg[i], 2);
        vst1_lane_u8(col_ptr + 3 * row_stride, src.buf.reg[i], 3);
        vst1_lane_u8(col_ptr + 4 * row_stride, src.buf.reg[i], 4);
        vst1_lane_u8(col_ptr + 5 * row_stride, src.buf.reg[i], 5);
        vst1_lane_u8(col_ptr + 6 * row_stride, src.buf.reg[i], 6);
        vst1_lane_u8(col_ptr + 7 * row_stride, src.buf.reg[i], 7);
      }
    }
  }
};
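
// Transposes an 8x8 block of uint8 lanes via three rounds of vtrn at 8-bit,
// 16-bit and 32-bit granularity.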
inline RegBlockUint8<8, 8> Transpose(const RegBlockUint8<8, 8>& src) {
  uint8x8x2_t a[4];
  a[0] = vtrn_u8(src.buf.reg[0], src.buf.reg[1]);
  a[1] = vtrn_u8(src.buf.reg[2], src.buf.reg[3]);
  a[2] = vtrn_u8(src.buf.reg[4], src.buf.reg[5]);
  a[3] = vtrn_u8(src.buf.reg[6], src.buf.reg[7]);
  uint16x4x2_t b[4];
  b[0] = vtrn_u16(vreinterpret_u16_u8(a[0].val[0]),
                  vreinterpret_u16_u8(a[1].val[0]));
  b[1] = vtrn_u16(vreinterpret_u16_u8(a[0].val[1]),
                  vreinterpret_u16_u8(a[1].val[1]));
  b[2] = vtrn_u16(vreinterpret_u16_u8(a[2].val[0]),
                  vreinterpret_u16_u8(a[3].val[0]));
  b[3] = vtrn_u16(vreinterpret_u16_u8(a[2].val[1]),
                  vreinterpret_u16_u8(a[3].val[1]));
  uint32x2x2_t c[4];
  c[0] = vtrn_u32(vreinterpret_u32_u16(b[0].val[0]),
                  vreinterpret_u32_u16(b[2].val[0]));
  c[1] = vtrn_u32(vreinterpret_u32_u16(b[1].val[0]),
                  vreinterpret_u32_u16(b[3].val[0]));
  c[2] = vtrn_u32(vreinterpret_u32_u16(b[0].val[1]),
                  vreinterpret_u32_u16(b[2].val[1]));
  c[3] = vtrn_u32(vreinterpret_u32_u16(b[1].val[1]),
                  vreinterpret_u32_u16(b[3].val[1]));
  RegBlockUint8<8, 8> result;
  result.buf.reg[0] = vreinterpret_u8_u32(c[0].val[0]);
  result.buf.reg[1] = vreinterpret_u8_u32(c[1].val[0]);
  result.buf.reg[2] = vreinterpret_u8_u32(c[2].val[0]);
  result.buf.reg[3] = vreinterpret_u8_u32(c[3].val[0]);
  result.buf.reg[4] = vreinterpret_u8_u32(c[0].val[1]);
  result.buf.reg[5] = vreinterpret_u8_u32(c[1].val[1]);
  result.buf.reg[6] = vreinterpret_u8_u32(c[2].val[1]);
  result.buf.reg[7] = vreinterpret_u8_u32(c[3].val[1]);
  return result;
}

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
  static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
                  int col) {
    const auto& block =
        DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
    std::uint8_t* dst_ptr = dst->data(row, col);
    int stride = dst->stride();
    for (int i = 0; i < 8; i++) {
      vst1_u8(dst_ptr + i * stride, block.buf.reg[i]);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt8<4, 1>, DstType> {
  static void Run(const RegBlockInt8<4, 1>& src, DstType* dst, int row,
                  int col) {
    const std::int32_t src_reg = src.buf.reg[0];
    for (int i = 0; i < 4; i++) {
      *dst->data(row + i, col) = (src_reg >> (8 * i));
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt8<1, 4>, DstType> {
  static void Run(const RegBlockInt8<1, 4>& src, DstType* dst, int row,
                  int col) {
    for (int i = 0; i < 4; i++) {
      *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt8<8, 1>, DstType> {
  static void Run(const RegBlockInt8<8, 1>& src, DstType* dst, int row,
                  int col) {
    std::int8_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      vst1_s8(dst_ptr, src.buf.reg[0]);
    } else {
      const int row_stride = dst->rows_stride();
      vst1_lane_s8(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
      vst1_lane_s8(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
      vst1_lane_s8(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
      vst1_lane_s8(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
      vst1_lane_s8(dst_ptr + 4 * row_stride, src.buf.reg[0], 4);
      vst1_lane_s8(dst_ptr + 5 * row_stride, src.buf.reg[0], 5);
      vst1_lane_s8(dst_ptr + 6 * row_stride, src.buf.reg[0], 6);
      vst1_lane_s8(dst_ptr + 7 * row_stride, src.buf.reg[0], 7);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt8<4, 4>, DstType> {
  static void Run(const RegBlockInt8<4, 4>& src, DstType* dst, int row,
                  int col) {
    std::int8_t* dst_ptr = dst->data(row, col);
    const int row_stride = dst->rows_stride();
    const int col_stride = dst->cols_stride();
    for (int i = 0; i < 2; i++) {
      vst1_lane_s8(dst_ptr + 0 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 0);
      vst1_lane_s8(dst_ptr + 1 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 1);
      vst1_lane_s8(dst_ptr + 2 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 2);
      vst1_lane_s8(dst_ptr + 3 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 3);
      vst1_lane_s8(dst_ptr + 0 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 4);
      vst1_lane_s8(dst_ptr + 1 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 5);
      vst1_lane_s8(dst_ptr + 2 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 6);
      vst1_lane_s8(dst_ptr + 3 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 7);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt8<8, 4>, DstType> {
  static void Run(const RegBlockInt8<8, 4>& src, DstType* dst, int row,
                  int col) {
    std::int8_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      int col_stride = dst->cols_stride();
      for (int i = 0; i < 4; i++) {
        vst1_s8(dst_ptr + i * col_stride, src.buf.reg[i]);
      }
    } else {
      int row_stride = dst->rows_stride();
      for (int i = 0; i < 4; i++) {
        std::int8_t* col_ptr = dst_ptr + i;
        vst1_lane_s8(col_ptr + 0 * row_stride, src.buf.reg[i], 0);
        vst1_lane_s8(col_ptr + 1 * row_stride, src.buf.reg[i], 1);
        vst1_lane_s8(col_ptr + 2 * row_stride, src.buf.reg[i], 2);
        vst1_lane_s8(col_ptr + 3 * row_stride, src.buf.reg[i], 3);
        vst1_lane_s8(col_ptr + 4 * row_stride, src.buf.reg[i], 4);
        vst1_lane_s8(col_ptr + 5 * row_stride, src.buf.reg[i], 5);
        vst1_lane_s8(col_ptr + 6 * row_stride, src.buf.reg[i], 6);
        vst1_lane_s8(col_ptr + 7 * row_stride, src.buf.reg[i], 7);
      }
    }
  }
};
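
// Same 8x8 in-register transpose as the uint8 version above, for signed 8-bit
// lanes.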
inline RegBlockInt8<8, 8> Transpose(const RegBlockInt8<8, 8>& src) {
  int8x8x2_t a[4];
  a[0] = vtrn_s8(src.buf.reg[0], src.buf.reg[1]);
  a[1] = vtrn_s8(src.buf.reg[2], src.buf.reg[3]);
  a[2] = vtrn_s8(src.buf.reg[4], src.buf.reg[5]);
  a[3] = vtrn_s8(src.buf.reg[6], src.buf.reg[7]);
  int16x4x2_t b[4];
  b[0] = vtrn_s16(vreinterpret_s16_s8(a[0].val[0]),
                  vreinterpret_s16_s8(a[1].val[0]));
  b[1] = vtrn_s16(vreinterpret_s16_s8(a[0].val[1]),
                  vreinterpret_s16_s8(a[1].val[1]));
  b[2] = vtrn_s16(vreinterpret_s16_s8(a[2].val[0]),
                  vreinterpret_s16_s8(a[3].val[0]));
  b[3] = vtrn_s16(vreinterpret_s16_s8(a[2].val[1]),
                  vreinterpret_s16_s8(a[3].val[1]));
  int32x2x2_t c[4];
  c[0] = vtrn_s32(vreinterpret_s32_s16(b[0].val[0]),
                  vreinterpret_s32_s16(b[2].val[0]));
  c[1] = vtrn_s32(vreinterpret_s32_s16(b[1].val[0]),
                  vreinterpret_s32_s16(b[3].val[0]));
  c[2] = vtrn_s32(vreinterpret_s32_s16(b[0].val[1]),
                  vreinterpret_s32_s16(b[2].val[1]));
  c[3] = vtrn_s32(vreinterpret_s32_s16(b[1].val[1]),
                  vreinterpret_s32_s16(b[3].val[1]));
  RegBlockInt8<8, 8> result;
  result.buf.reg[0] = vreinterpret_s8_s32(c[0].val[0]);
  result.buf.reg[1] = vreinterpret_s8_s32(c[1].val[0]);
  result.buf.reg[2] = vreinterpret_s8_s32(c[2].val[0]);
  result.buf.reg[3] = vreinterpret_s8_s32(c[3].val[0]);
  result.buf.reg[4] = vreinterpret_s8_s32(c[0].val[1]);
  result.buf.reg[5] = vreinterpret_s8_s32(c[1].val[1]);
  result.buf.reg[6] = vreinterpret_s8_s32(c[2].val[1]);
  result.buf.reg[7] = vreinterpret_s8_s32(c[3].val[1]);
  return result;
}

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt8<8, 8>, DstType> {
  static void Run(const RegBlockInt8<8, 8>& src, DstType* dst, int row,
                  int col) {
    const auto& block =
        DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
    std::int8_t* dst_ptr = dst->data(row, col);
    int stride = dst->stride();
    for (int i = 0; i < 8; i++) {
      vst1_s8(dst_ptr + i * stride, block.buf.reg[i]);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
  static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      vst1q_s16(dst->data(row, col + 0), src.buf.reg[0]);
      vst1q_s16(dst->data(row, col + 1), src.buf.reg[1]);
      vst1q_s16(dst->data(row, col + 2), src.buf.reg[2]);
      vst1q_s16(dst->data(row, col + 3), src.buf.reg[3]);
      vst1q_s16(dst->data(row, col + 4), src.buf.reg[4]);
      vst1q_s16(dst->data(row, col + 5), src.buf.reg[5]);
      vst1q_s16(dst->data(row, col + 6), src.buf.reg[6]);
      vst1q_s16(dst->data(row, col + 7), src.buf.reg[7]);
    } else {
      int16x8x2_t a[4];
      a[0] = vtrnq_s16(src.buf.reg[0], src.buf.reg[1]);
      a[1] = vtrnq_s16(src.buf.reg[2], src.buf.reg[3]);
      a[2] = vtrnq_s16(src.buf.reg[4], src.buf.reg[5]);
      a[3] = vtrnq_s16(src.buf.reg[6], src.buf.reg[7]);
      int32x4x2_t b[4];
      b[0] = vtrnq_s32(vreinterpretq_s32_s16(a[0].val[0]),
                       vreinterpretq_s32_s16(a[1].val[0]));
      b[1] = vtrnq_s32(vreinterpretq_s32_s16(a[0].val[1]),
                       vreinterpretq_s32_s16(a[1].val[1]));
      b[2] = vtrnq_s32(vreinterpretq_s32_s16(a[2].val[0]),
                       vreinterpretq_s32_s16(a[3].val[0]));
      b[3] = vtrnq_s32(vreinterpretq_s32_s16(a[2].val[1]),
                       vreinterpretq_s32_s16(a[3].val[1]));
      vst1_s16(dst->data(row + 0, col + 0),
               vget_low_s16(vreinterpretq_s16_s32(b[0].val[0])));
      vst1_s16(dst->data(row + 0, col + 4),
               vget_low_s16(vreinterpretq_s16_s32(b[2].val[0])));
      vst1_s16(dst->data(row + 1, col + 0),
               vget_low_s16(vreinterpretq_s16_s32(b[1].val[0])));
      vst1_s16(dst->data(row + 1, col + 4),
               vget_low_s16(vreinterpretq_s16_s32(b[3].val[0])));
      vst1_s16(dst->data(row + 2, col + 0),
               vget_low_s16(vreinterpretq_s16_s32(b[0].val[1])));
      vst1_s16(dst->data(row + 2, col + 4),
               vget_low_s16(vreinterpretq_s16_s32(b[2].val[1])));
      vst1_s16(dst->data(row + 3, col + 0),
               vget_low_s16(vreinterpretq_s16_s32(b[1].val[1])));
      vst1_s16(dst->data(row + 3, col + 4),
               vget_low_s16(vreinterpretq_s16_s32(b[3].val[1])));
      vst1_s16(dst->data(row + 4, col + 0),
               vget_high_s16(vreinterpretq_s16_s32(b[0].val[0])));
      vst1_s16(dst->data(row + 4, col + 4),
               vget_high_s16(vreinterpretq_s16_s32(b[2].val[0])));
      vst1_s16(dst->data(row + 5, col + 0),
               vget_high_s16(vreinterpretq_s16_s32(b[1].val[0])));
      vst1_s16(dst->data(row + 5, col + 4),
               vget_high_s16(vreinterpretq_s16_s32(b[3].val[0])));
      vst1_s16(dst->data(row + 6, col + 0),
               vget_high_s16(vreinterpretq_s16_s32(b[0].val[1])));
      vst1_s16(dst->data(row + 6, col + 4),
               vget_high_s16(vreinterpretq_s16_s32(b[2].val[1])));
      vst1_s16(dst->data(row + 7, col + 0),
               vget_high_s16(vreinterpretq_s16_s32(b[1].val[1])));
      vst1_s16(dst->data(row + 7, col + 4),
               vget_high_s16(vreinterpretq_s16_s32(b[3].val[1])));
    }
  }
};

}  // namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_OUTPUT_NEON_H_