1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
16 #define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
17 #endif
#include "eight_bit_int_gemm.h"

#include <cstddef>
#include <cstdint>
#include <memory>
21
// gemmlowp symbols should have hidden visibility.
// Currently this is ensured in the build system by
// passing -fvisibility-inlines-hidden. TODO: it would be
// safer to hardcode it here with some #pragmas.
26 #include "../public/gemmlowp.h"
27
// Define GEMMLOWP_USE_META_FASTPATH in order to use the fastpath ARM/NEON
// code. This code path consists of a number of meta-programmed, automatically
// generated GEMM kernels that are suitable for some sizes of input matrices.
// Because the generated code relies heavily on loop unrolling, inlining and
// currying of runtime parameters, the generated binary is quite large
// (approx. 200 KB), which might be prohibitive in low-memory situations.
35
36 #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
37 #include "../meta/legacy_multi_thread_gemm.h"
38 #else
39
40 #if defined(GEMMLOWP_USE_META_FASTPATH)
41 #warning "META fast path turned on without NEON!"
42 #endif
43
44 #endif
45
46 namespace gemmlowp {
47 namespace eight_bit_int_gemm {
48 namespace {
49
50 // To be used as template parameter for GlobalLock.
51 // GlobalLock<EightBitIntGemmLockId> is the global lock
52 // on EightBitIntGemm entry points, protecting
53 // EightBitIntGemm's global state.
54 struct EightBitIntGemmLockId;
55
56 // Global state: consists of one global GemmContext instance.
57 GemmContext* global_context;
58
GetOrCreateGlobalContext()59 GemmContext* GetOrCreateGlobalContext() {
60 if (!global_context) {
61 global_context = new GemmContext;
62 }
63 return global_context;
64 }
65
DestroyGlobalContext()66 void DestroyGlobalContext() {
67 delete global_context;
68 global_context = nullptr;
69 }
70
71 template <bool transpose_a, bool transpose_b, bool transpose_c>
EightBitIntGemmImpl(GemmContext * context,int m,int n,int k,const std::uint8_t * a,std::int32_t a_offset,int lda,const std::uint8_t * b,std::int32_t b_offset,int ldb,std::uint8_t * c,std::int32_t c_offset,std::int32_t c_mult_int,std::int32_t c_shift,int ldc,BitDepthSetting bit_depth)72 void EightBitIntGemmImpl(GemmContext* context, int m, int n, int k,
73 const std::uint8_t* a, std::int32_t a_offset, int lda,
74 const std::uint8_t* b, std::int32_t b_offset, int ldb,
75 std::uint8_t* c, std::int32_t c_offset,
76 std::int32_t c_mult_int, std::int32_t c_shift, int ldc,
77 BitDepthSetting bit_depth) {
78 const int lhs_offset = a_offset;
79 const int rhs_offset = b_offset;
80 const int result_offset = c_offset;
81 const int result_mult_int = c_mult_int;
82 const int result_shift = c_shift;
83
84 static const MapOrder ResultOrder =
85 transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
86 static const MapOrder LhsOrder =
87 transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
88 static const MapOrder RhsOrder =
89 transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;
90
91 MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
92 MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
93 MatrixMap<std::uint8_t, ResultOrder> result(c, m, n, ldc);
94
95 switch (bit_depth) {
96 #define GEMMLOWP_HANDLE_BIT_DEPTH(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \
97 case BitDepthSetting::BIT_DEPTH_SETTING: \
98 Gemm<std::uint8_t, BIT_DEPTH_PARAMS>( \
99 context, lhs, rhs, &result, lhs_offset, rhs_offset, result_offset, \
100 result_mult_int, result_shift); \
101 return;
102 GEMMLOWP_HANDLE_BIT_DEPTH(A8B8, DefaultL8R8BitDepthParams)
103 GEMMLOWP_HANDLE_BIT_DEPTH(A5B7, DefaultL7R5BitDepthParams)
104 default:
105 abort();
106 #undef GEMMLOWP_HANDLE_BIT_DEPTH
107 }
108 }
109
110 template <bool transpose_a, bool transpose_b, bool transpose_c>
EightBitIntGemmInt32Impl(GemmContext * context,int m,int n,int k,const std::uint8_t * a,std::int32_t a_offset,int lda,const std::uint8_t * b,std::int32_t b_offset,int ldb,std::int32_t * c,int ldc,BitDepthSetting bit_depth)111 void EightBitIntGemmInt32Impl(GemmContext* context, int m, int n, int k,
112 const std::uint8_t* a, std::int32_t a_offset,
113 int lda, const std::uint8_t* b,
114 std::int32_t b_offset, int ldb, std::int32_t* c,
115 int ldc, BitDepthSetting bit_depth) {
116 const int lhs_offset = a_offset;
117 const int rhs_offset = b_offset;
118
119 static const MapOrder ResultOrder =
120 transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
121 static const MapOrder LhsOrder =
122 transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
123 static const MapOrder RhsOrder =
124 transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;
125
126 MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
127 MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
128 MatrixMap<std::int32_t, ResultOrder> result(c, m, n, ldc);
129
130 auto empty_pipeline = std::make_tuple();
131
132 switch (bit_depth) {
133 #define GEMMLOWP_HANDLE_BIT_DEPTH_INT32(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \
134 case BitDepthSetting::BIT_DEPTH_SETTING: \
135 GemmWithOutputPipeline<std::uint8_t, std::int32_t, BIT_DEPTH_PARAMS>( \
136 context, lhs, rhs, &result, lhs_offset, rhs_offset, empty_pipeline); \
137 return;
138 GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A8B8, DefaultL8R8BitDepthParams)
139 GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A5B7, DefaultL7R5BitDepthParams)
140 default:
141 abort();
142 #undef GEMMLOWP_HANDLE_BIT_DEPTH_INT32
143 }
144 }
145
// Grow-only heap scratch buffer whose usable start (buffer()) is aligned to
// kAlignment (32) bytes. Sizes are byte counts in std::size_t so large
// requests are not truncated to 32 bits.
class Scratch {
 public:
  Scratch() = default;

  // Ensures buffer() points at least `required_size` usable bytes. A buffer
  // that is already large enough is kept as-is. Otherwise a new one is
  // allocated, and the previous contents are not preserved.
  // Aborts if adding the alignment padding would overflow std::size_t, e.g.
  // when a negative int32 size was converted to std::size_t.
  void AssureSize(std::size_t required_size) {
    if (size_ >= required_size) {
      return;
    }
    if (required_size > SIZE_MAX - kAlignment) {
      abort();
    }
    // Over-allocate by kAlignment bytes so an aligned start always fits.
    buffer_.reset(new std::uint8_t[required_size + kAlignment]);
    const std::uintptr_t address =
        reinterpret_cast<std::uintptr_t>(buffer_.get());
    buffer_32_ =
        buffer_.get() + (kAlignment - address % kAlignment) % kAlignment;
    assert(reinterpret_cast<std::uintptr_t>(buffer_32_) % kAlignment == 0);
    size_ = required_size;
  }

  // Releases the buffer; buffer() returns nullptr until the next AssureSize.
  void Clear() {
    buffer_.reset(nullptr);
    buffer_32_ = nullptr;
    size_ = 0;
  }

  // Aligned start of the usable region, or nullptr if nothing is allocated.
  std::uint8_t* buffer() { return buffer_32_; }

 private:
  static constexpr std::size_t kAlignment = 32;

  std::unique_ptr<std::uint8_t[]> buffer_;  // Owns the raw allocation.
  std::uint8_t* buffer_32_ = nullptr;       // Aligned pointer into buffer_.
  std::size_t size_ = 0;                    // Usable bytes at buffer_32_.
};
175
176 Scratch* global_scratch = nullptr;
177
GetOrCreateGlobalScratch()178 Scratch* GetOrCreateGlobalScratch() {
179 if (global_scratch == nullptr) {
180 global_scratch = new Scratch();
181 }
182 return global_scratch;
183 }
184
DestroyGlobalScratch()185 void DestroyGlobalScratch() {
186 delete global_scratch;
187 global_scratch = nullptr;
188 }
189
190 #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
191
// Returns true when a rows x cols operand occupies rows*cols contiguous
// elements in row-major order. The meta fast-path kernels take no strides,
// so they require this. `transpose` == true means the operand is stored
// row-major, false means column-major, with `stride` elements between
// consecutive rows (row-major) or columns (column-major).
//  - A row-major matrix qualifies when it is densely packed (stride == cols),
//    or when it has a single row (the stride is then irrelevant).
//  - A column-major matrix qualifies only as a single row whose columns are
//    adjacent (stride == 1). With stride > 1 its elements are `stride`
//    apart, which stride-less kernels would misread.
bool IsRowMajorOrVector(bool transpose, int stride, int rows, int cols) {
  if (transpose) {
    return stride == cols || rows == 1;
  }
  return rows == 1 && stride == 1;
}
205
// Returns true when a rows x cols operand occupies rows*cols contiguous
// elements in column-major order. The meta fast-path kernels take no strides,
// so they require this. `transpose` == true means the operand is stored
// row-major, false means column-major, with `stride` elements between
// consecutive rows (row-major) or columns (column-major).
//  - A column-major matrix qualifies when it is densely packed
//    (stride == rows), or when it has a single column (the stride is then
//    irrelevant).
//  - A row-major matrix qualifies only as a single column whose rows are
//    adjacent (stride == 1). With stride > 1 its elements are `stride` apart,
//    which stride-less kernels would misread.
bool IsColumnMajorOrVector(bool transpose, int stride, int rows, int cols) {
  if (!transpose) {
    return stride == rows || cols == 1;
  }
  return cols == 1 && stride == 1;
}
219
CanHandleMetaFastpath(bool transpose_a,bool transpose_b,bool transpose_c,int m,int n,int k,int lda,int ldb,int ldc,BitDepthSetting depth_setting)220 bool CanHandleMetaFastpath(bool transpose_a, bool transpose_b, bool transpose_c,
221 int m, int n, int k, int lda, int ldb, int ldc,
222 BitDepthSetting depth_setting) {
223 // Meta fastpath only supports 8bit x 8bit and k between 8 and 2048.
224 if (depth_setting != BitDepthSetting::A8B8 || k < 8 || k > 2048) {
225 return false;
226 }
227
228 // The first operand needs to be a row major matrix or a vector.
229 if (!IsRowMajorOrVector(transpose_a, lda, m, k)) {
230 return false;
231 }
232
233 // The second operand needs to be a column major matrix or a vector.
234 if (!IsColumnMajorOrVector(transpose_b, ldb, k, n)) {
235 return false;
236 }
237
238 // The result can either be a row major matrix, a column major matrix or
239 // a vector.
240 if (IsRowMajorOrVector(transpose_c, ldc, m, n)) {
241 return true;
242 }
243
244 if (IsColumnMajorOrVector(transpose_c, ldc, m, n)) {
245 return true;
246 }
247
248 return false;
249 }
250
251 // Assure enough scratch memory is allocated and run the fast path gemm.
MetaGemmQuantized8Bit(GemmContext * context,const std::uint8_t * lhs,const std::uint8_t * rhs,int m,int n,int k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t sum_offset,std::int32_t multiplicative_offset,std::int32_t shift,bool result_transpose,std::int32_t result_stride,std::uint8_t * result)252 void MetaGemmQuantized8Bit(GemmContext* context, const std::uint8_t* lhs,
253 const std::uint8_t* rhs, int m, int n, int k,
254 std::int32_t lhs_offset, std::int32_t rhs_offset,
255 std::int32_t sum_offset,
256 std::int32_t multiplicative_offset,
257 std::int32_t shift, bool result_transpose,
258 std::int32_t result_stride, std::uint8_t* result) {
259 Scratch* scratch = GetOrCreateGlobalScratch();
260 const std::int32_t max_num_threads = context->max_num_threads();
261 if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
262 scratch->AssureSize(meta::gemm_q8_scratch(m, n, k, max_num_threads));
263 meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads,
264 scratch->buffer(), lhs, rhs, m, n, k, lhs_offset,
265 rhs_offset, sum_offset, multiplicative_offset,
266 shift, result);
267 } else {
268 scratch->AssureSize(meta::gemm_q8_scratch(n, m, k, max_num_threads));
269 meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads,
270 scratch->buffer(), rhs, lhs, n, m, k, rhs_offset,
271 lhs_offset, sum_offset, multiplicative_offset,
272 shift, result);
273 }
274 }
275
276 // Assure enough scratch memory is allocated and run the 8bit to float fast
277 // path gemm.
MetaGemmFloat(GemmContext * context,const std::uint8_t * lhs,const std::uint8_t * rhs,int m,int n,int k,std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset,bool result_transpose,std::int32_t result_stride,float * result)278 void MetaGemmFloat(GemmContext* context, const std::uint8_t* lhs,
279 const std::uint8_t* rhs, int m, int n, int k,
280 std::int32_t lhs_offset, std::int32_t rhs_offset,
281 float result_offset, bool result_transpose,
282 std::int32_t result_stride, float* result) {
283 Scratch* scratch = GetOrCreateGlobalScratch();
284 const std::int32_t max_num_threads = context->max_num_threads();
285 if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
286 scratch->AssureSize(meta::gemm_f_scratch(m, n, k, max_num_threads));
287 meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads,
288 scratch->buffer(), lhs, rhs, m, n, k, lhs_offset,
289 rhs_offset, result_offset, result);
290 } else {
291 scratch->AssureSize(meta::gemm_f_scratch(n, m, k, max_num_threads));
292 meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads,
293 scratch->buffer(), rhs, lhs, n, m, k, rhs_offset,
294 lhs_offset, result_offset, result);
295 }
296 }
297
298 #endif
299
300 } // end anonymous namespace
301
302 // Public interface entry points
303
EightBitIntGemm(bool transpose_a,bool transpose_b,bool transpose_c,int m,int n,int k,const std::uint8_t * a,std::int32_t a_offset,int lda,const std::uint8_t * b,std::int32_t b_offset,int ldb,std::uint8_t * c,std::int32_t c_offset,std::int32_t c_mult_int,std::int32_t c_shift,int ldc,BitDepthSetting bit_depth)304 void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
305 int m, int n, int k, const std::uint8_t* a,
306 std::int32_t a_offset, int lda, const std::uint8_t* b,
307 std::int32_t b_offset, int ldb, std::uint8_t* c,
308 std::int32_t c_offset, std::int32_t c_mult_int,
309 std::int32_t c_shift, int ldc, BitDepthSetting bit_depth) {
310 ScopedLock sl(GlobalMutexes::EightBitIntGemm());
311 GemmContext* context = GetOrCreateGlobalContext();
312
313 #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
314 if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
315 ldb, ldc, bit_depth)) {
316 MetaGemmQuantized8Bit(context, a, b, m, n, k, a_offset, b_offset, c_offset,
317 c_mult_int, c_shift, transpose_c, ldc, c);
318 return;
319 }
320 #endif
321
322 #define GEMMLOWP_HANDLE_CASE(ta, tb, tc) \
323 if (transpose_a == ta && transpose_b == tb && transpose_c == tc) { \
324 EightBitIntGemmImpl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, b, \
325 b_offset, ldb, c, c_offset, c_mult_int, \
326 c_shift, ldc, bit_depth); \
327 }
328
329 GEMMLOWP_HANDLE_CASE(false, false, false)
330 GEMMLOWP_HANDLE_CASE(false, false, true)
331 GEMMLOWP_HANDLE_CASE(false, true, false)
332 GEMMLOWP_HANDLE_CASE(false, true, true)
333 GEMMLOWP_HANDLE_CASE(true, false, false)
334 GEMMLOWP_HANDLE_CASE(true, false, true)
335 GEMMLOWP_HANDLE_CASE(true, true, false)
336 GEMMLOWP_HANDLE_CASE(true, true, true)
337
338 #undef GEMMLOWP_HANDLE_CASE
339 }
340
EightBitIntGemm(bool transpose_a,bool transpose_b,bool transpose_c,int m,int n,int k,const std::uint8_t * a,std::int32_t a_offset,std::int32_t lda,const std::uint8_t * b,std::int32_t b_offset,std::int32_t ldb,float * c,float c_offset,std::int32_t ldc,BitDepthSetting bit_depth)341 void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
342 int m, int n, int k, const std::uint8_t* a,
343 std::int32_t a_offset, std::int32_t lda,
344 const std::uint8_t* b, std::int32_t b_offset,
345 std::int32_t ldb, float* c, float c_offset,
346 std::int32_t ldc, BitDepthSetting bit_depth) {
347 ScopedLock sl(GlobalMutexes::EightBitIntGemm());
348 GemmContext* context = GetOrCreateGlobalContext();
349
350 #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
351 if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
352 ldb, ldc, bit_depth)) {
353 MetaGemmFloat(context, a, b, m, n, k, a_offset, b_offset, c_offset,
354 transpose_c, ldc, c);
355 return;
356 }
357 #endif
358
359 // TODO(maciekc): implement a float output stage, get rid of scratch memory.
360 Scratch* scratch = GetOrCreateGlobalScratch();
361 if (transpose_c) {
362 scratch->AssureSize(m * ldc * sizeof(std::int32_t));
363 } else {
364 scratch->AssureSize(n * ldc * sizeof(std::int32_t));
365 }
366 std::int32_t* temp_c = reinterpret_cast<std::int32_t*>(scratch->buffer());
367
368 #define GEMMLOWP_HANDLE_INT32_CASE(ta, tb, tc) \
369 if (transpose_a == ta && transpose_b == tb && transpose_c == tc) { \
370 EightBitIntGemmInt32Impl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, \
371 b, b_offset, ldb, temp_c, ldc, \
372 bit_depth); \
373 }
374
375 GEMMLOWP_HANDLE_INT32_CASE(false, false, false)
376 GEMMLOWP_HANDLE_INT32_CASE(false, false, true)
377 GEMMLOWP_HANDLE_INT32_CASE(false, true, false)
378 GEMMLOWP_HANDLE_INT32_CASE(false, true, true)
379 GEMMLOWP_HANDLE_INT32_CASE(true, false, false)
380 GEMMLOWP_HANDLE_INT32_CASE(true, false, true)
381 GEMMLOWP_HANDLE_INT32_CASE(true, true, false)
382 GEMMLOWP_HANDLE_INT32_CASE(true, true, true)
383
384 #undef GEMMLOWP_HANDLE_INT32_CASE
385
386 if (transpose_c) {
387 // Row major.
388 for (int i = 0; i < m; ++i) {
389 float* dest_row = c + i * ldc;
390 std::int32_t* src_row = temp_c + i * ldc;
391 for (int j = 0; j < n; ++j) {
392 dest_row[j] = static_cast<float>(src_row[j]) * c_offset;
393 }
394 }
395 } else {
396 // Column major.
397 for (int i = 0; i < n; ++i) {
398 float* dest_column = c + i * ldc;
399 std::int32_t* src_column = temp_c + i * ldc;
400 for (int j = 0; j < m; ++j) {
401 dest_column[j] = static_cast<float>(src_column[j]) * c_offset;
402 }
403 }
404 }
405 }
406
SetMaxNumThreads(int n)407 void SetMaxNumThreads(int n) {
408 ScopedLock sl(GlobalMutexes::EightBitIntGemm());
409 GemmContext* context = GetOrCreateGlobalContext();
410 context->set_max_num_threads(n);
411 }
412
FreePersistentResources()413 void FreePersistentResources() {
414 ScopedLock sl(GlobalMutexes::EightBitIntGemm());
415 DestroyGlobalContext();
416 DestroyGlobalScratch();
417 }
418
419 } // namespace eight_bit_int_gemm
420 } // namespace gemmlowp
421