1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // pack_SSE.h: optimized SSE specializations of the templates in pack.h.
16
17 #ifndef GEMMLOWP_INTERNAL_PACK_SSE_H_
18 #define GEMMLOWP_INTERNAL_PACK_SSE_H_
19
20 #include <smmintrin.h>
21 #include "pack.h"
22
23 namespace gemmlowp {
24
25 // Requantizes source values pointed by raw_src_ptr in [0..255] range
26 // to the range specified by BitDepth, [0..((2^bits)-1)].
27 // This is in-place requantization, where the input is
28 // not modified if 8bit integers are used. SSE does not
29 // have less than 8bit kernels currently. Altought SSE registers
30 // hold 16 uint8_t elements, only first 8 uint8_t elements are
31 // requantized. The packing only use first 8 uint8_t elements
32 // of the SSE registers. Therefore, requantizing all 16 uint8_t
33 // elements will be wasteful computation.
34 template <typename QuantizationParams>
SSERequantize(__m128i * raw_src_ptr,ScalarRoundingOffsetGenerator<QuantizationParams::kRoundingMode> * rounding_offset_generator)35 void SSERequantize(
36 __m128i* raw_src_ptr,
37 ScalarRoundingOffsetGenerator<QuantizationParams::kRoundingMode>*
38 rounding_offset_generator) {
39 static const int kBits = QuantizationParams::BitDepth::kBits;
40 static const std::uint8_t kMaxVal = (1 << kBits) - 1;
41 if (kBits == 8) {
42 return;
43 }
44
45 std::uint8_t* raw_src_ui8_ptr = (std::uint8_t*)&raw_src_ptr[0];
46
47 // modify only first 8 elements in the register (see note above)
48 for (int i = 0; i < 8; ++i) {
49 std::uint16_t scaled =
50 static_cast<std::uint16_t>(raw_src_ui8_ptr[i]) * kMaxVal;
51 std::uint8_t rounding_offset = rounding_offset_generator->get();
52 raw_src_ui8_ptr[i] = (scaled + rounding_offset) / 255;
53 }
54 }
55
56 // TODO: Add DepthMajorUint8SideMap
57
58 typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
59 WidthMajorUint8SideMap;
60
61 template <int Cells>
62 using WidthMajorSideFormatNCells4x2 =
63 KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;
64
65 template <typename QuantizationParams, int Cells>
66 class PackingRegisterBlock<
67 QuantizationParams, WidthMajorUint8SideMap,
68 PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > >
69 : public PackingRegisterBlockBase<
70 QuantizationParams, WidthMajorUint8SideMap,
71 PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > > {
72 public:
73 typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
74 typedef typename KernelSideFormat::Cell CellFormat;
75 static const int kCells = KernelSideFormat::kCells;
76 static const int kCellWidth = CellFormat::kWidth;
77 static const int kKernelWidth = CellFormat::kWidth * kCells;
78 static const int kCellDepth = CellFormat::kDepth;
79 static const int kCellSize = CellFormat::kSize;
80
81 typedef ScalarRoundingOffsetGenerator<QuantizationParams::kRoundingMode>
82 RoundingOffsetGenerator;
83
Pack(PackedSideBlock<KernelSideFormat> * dst,int start_width,RoundingOffsetGenerator * rounding_offset_generator)84 void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width,
85 RoundingOffsetGenerator* rounding_offset_generator) {
86 std::uint8_t* dst_ptr = dst->current_data();
87 const int width_stride = this->complete_src_.width_stride();
88 int depth_step = 8;
89
90 __m128i one = _mm_set1_epi16(1);
91 for (int cell_start_depth = 0; cell_start_depth < kRegisterSize;
92 cell_start_depth += depth_step) {
93 for (int cell_start_width = 0; cell_start_width < kKernelWidth;
94 cell_start_width += kCellWidth) {
95 std::int32_t* cell_sums_of_each_slice_ptr =
96 dst->sums_of_each_slice() + start_width + cell_start_width;
97 const std::uint8_t* src_data =
98 this->complete_src_.data(cell_start_width, cell_start_depth);
99
100 __m128i xmm1 =
101 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(&src_data[0]));
102 __m128i xmm2 = _mm_loadl_epi64(
103 reinterpret_cast<const __m128i*>(&src_data[1 * width_stride]));
104 __m128i xmm3 = _mm_loadl_epi64(
105 reinterpret_cast<const __m128i*>(&src_data[2 * width_stride]));
106 __m128i xmm4 = _mm_loadl_epi64(
107 reinterpret_cast<const __m128i*>(&src_data[3 * width_stride]));
108
109 __m128i xmm5 = _mm_unpacklo_epi16(xmm1, xmm2);
110 __m128i xmm8 = _mm_shuffle_epi32(xmm5, 0x31);
111
112 __m128i xmm6 = _mm_unpacklo_epi16(xmm3, xmm4);
113 __m128i xmm7 = _mm_shuffle_epi32(xmm6, 0x80);
114
115 __m128i xmm9 = _mm_blend_epi16(xmm5, xmm7, 0xcc);
116 SSERequantize<QuantizationParams>(&xmm9, rounding_offset_generator);
117
118 __m128i xmm10 = _mm_blend_epi16(xmm8, xmm6, 0xcc);
119 SSERequantize<QuantizationParams>(&xmm10, rounding_offset_generator);
120
121 _mm_storel_epi64(reinterpret_cast<__m128i*>(&dst_ptr[0]), xmm9);
122 _mm_storel_epi64(
123 reinterpret_cast<__m128i*>(&dst_ptr[kCellSize * kCells]), xmm10);
124
125 __m128i xmm11 = _mm_shuffle_epi32(xmm9, 0xee);
126 SSERequantize<QuantizationParams>(&xmm11, rounding_offset_generator);
127
128 __m128i xmm12 = _mm_shuffle_epi32(xmm10, 0xee);
129 SSERequantize<QuantizationParams>(&xmm12, rounding_offset_generator);
130
131 _mm_storel_epi64(
132 reinterpret_cast<__m128i*>(&dst_ptr[2 * kCellSize * kCells]),
133 xmm11);
134 _mm_storel_epi64(
135 reinterpret_cast<__m128i*>(&dst_ptr[3 * kCellSize * kCells]),
136 xmm12);
137
138 xmm1 = _mm_cvtepu8_epi16(xmm9);
139 xmm2 = _mm_madd_epi16(xmm1, one);
140 __m128i sums_of_each_slice_xmm = _mm_loadu_si128(
141 reinterpret_cast<const __m128i*>(&cell_sums_of_each_slice_ptr[0]));
142 sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
143
144 xmm1 = _mm_cvtepu8_epi16(xmm10);
145 xmm2 = _mm_madd_epi16(xmm1, one);
146 sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
147
148 xmm1 = _mm_cvtepu8_epi16(xmm11);
149 xmm2 = _mm_madd_epi16(xmm1, one);
150 sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
151
152 xmm1 = _mm_cvtepu8_epi16(xmm12);
153 xmm2 = _mm_madd_epi16(xmm1, one);
154 sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
155
156 _mm_storeu_si128(
157 reinterpret_cast<__m128i*>(&cell_sums_of_each_slice_ptr[0]),
158 sums_of_each_slice_xmm);
159 dst_ptr += kCellSize;
160 }
161 dst_ptr += 3 * kCellSize * kCells;
162 }
163 dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
164 }
165 };
166
167 } // namespace gemmlowp
168
169 #endif // GEMMLOWP_INTERNAL_PACK_SSE_H_
170