1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
16 #define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
17 #endif
#include "eight_bit_int_gemm.h"

#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
21 
// gemmlowp symbols should have hidden visibility.
// Currently this is ensured in the build system by
// passing -fvisibility-inlines-hidden. TODO: it would be
// safer to hardcode it here with some #pragmas.
26 #include "../public/gemmlowp.h"
27 
// Define GEMMLOWP_USE_META_FASTPATH in order to use the fastpath ARM/NEON
// code. This code path consists of a number of meta-programmed, automatically
// generated GEMM kernels that are suitable for some sizes of input matrices.
// Because the generated code relies heavily on loop unrolling, inlining and
// currying of runtime parameters, the size of the generated binary is quite
// significant (approx. 200kb), which might be prohibitive in low-memory
// situations.
35 
36 #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
37 #include "../meta/legacy_multi_thread_gemm.h"
38 #else
39 
40 #if defined(GEMMLOWP_USE_META_FASTPATH)
41 #warning "META fast path turned on without NEON!"
42 #endif
43 
44 #endif
45 
46 namespace gemmlowp {
47 namespace eight_bit_int_gemm {
48 namespace {
49 
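// To be used as template parameter for GlobalLock.
// GlobalLock<EightBitIntGemmLockId> is the global lock
// on EightBitIntGemm entry points, protecting
// EightBitIntGemm's global state.
// NOTE(review): the entry points below take their lock through
// ScopedLock on GlobalMutexes::EightBitIntGemm() rather than through this
// tag -- confirm whether this declaration is still needed.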
54 struct EightBitIntGemmLockId;
55 
56 // Global state: consists of one global GemmContext instance.
57 GemmContext* global_context;
58 
GetOrCreateGlobalContext()59 GemmContext* GetOrCreateGlobalContext() {
60   if (!global_context) {
61     global_context = new GemmContext;
62   }
63   return global_context;
64 }
65 
DestroyGlobalContext()66 void DestroyGlobalContext() {
67   delete global_context;
68   global_context = nullptr;
69 }
70 
71 template <bool transpose_a, bool transpose_b, bool transpose_c>
EightBitIntGemmImpl(GemmContext * context,int m,int n,int k,const std::uint8_t * a,std::int32_t a_offset,int lda,const std::uint8_t * b,std::int32_t b_offset,int ldb,std::uint8_t * c,std::int32_t c_offset,std::int32_t c_mult_int,std::int32_t c_shift,int ldc,BitDepthSetting bit_depth)72 void EightBitIntGemmImpl(GemmContext* context, int m, int n, int k,
73                          const std::uint8_t* a, std::int32_t a_offset, int lda,
74                          const std::uint8_t* b, std::int32_t b_offset, int ldb,
75                          std::uint8_t* c, std::int32_t c_offset,
76                          std::int32_t c_mult_int, std::int32_t c_shift, int ldc,
77                          BitDepthSetting bit_depth) {
78   const int lhs_offset = a_offset;
79   const int rhs_offset = b_offset;
80   const int result_offset = c_offset;
81   const int result_mult_int = c_mult_int;
82   const int result_shift = c_shift;
83 
84   static const MapOrder ResultOrder =
85       transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
86   static const MapOrder LhsOrder =
87       transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
88   static const MapOrder RhsOrder =
89       transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;
90 
91   MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
92   MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
93   MatrixMap<std::uint8_t, ResultOrder> result(c, m, n, ldc);
94 
95   switch (bit_depth) {
96 #define GEMMLOWP_HANDLE_BIT_DEPTH(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS)     \
97   case BitDepthSetting::BIT_DEPTH_SETTING:                                 \
98     Gemm<std::uint8_t, BIT_DEPTH_PARAMS>(                                  \
99         context, lhs, rhs, &result, lhs_offset, rhs_offset, result_offset, \
100         result_mult_int, result_shift);                                    \
101     return;
102     GEMMLOWP_HANDLE_BIT_DEPTH(A8B8, DefaultL8R8BitDepthParams)
103     GEMMLOWP_HANDLE_BIT_DEPTH(A5B7, DefaultL7R5BitDepthParams)
104     default:
105       abort();
106 #undef GEMMLOWP_HANDLE_BIT_DEPTH
107   }
108 }
109 
110 template <bool transpose_a, bool transpose_b, bool transpose_c>
EightBitIntGemmInt32Impl(GemmContext * context,int m,int n,int k,const std::uint8_t * a,std::int32_t a_offset,int lda,const std::uint8_t * b,std::int32_t b_offset,int ldb,std::int32_t * c,int ldc,BitDepthSetting bit_depth)111 void EightBitIntGemmInt32Impl(GemmContext* context, int m, int n, int k,
112                               const std::uint8_t* a, std::int32_t a_offset,
113                               int lda, const std::uint8_t* b,
114                               std::int32_t b_offset, int ldb, std::int32_t* c,
115                               int ldc, BitDepthSetting bit_depth) {
116   const int lhs_offset = a_offset;
117   const int rhs_offset = b_offset;
118 
119   static const MapOrder ResultOrder =
120       transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
121   static const MapOrder LhsOrder =
122       transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
123   static const MapOrder RhsOrder =
124       transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;
125 
126   MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
127   MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
128   MatrixMap<std::int32_t, ResultOrder> result(c, m, n, ldc);
129 
130   auto empty_pipeline = std::make_tuple();
131 
132   switch (bit_depth) {
133 #define GEMMLOWP_HANDLE_BIT_DEPTH_INT32(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \
134   case BitDepthSetting::BIT_DEPTH_SETTING:                                   \
135     GemmWithOutputPipeline<std::uint8_t, std::int32_t, BIT_DEPTH_PARAMS>(    \
136         context, lhs, rhs, &result, lhs_offset, rhs_offset, empty_pipeline); \
137     return;
138     GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A8B8, DefaultL8R8BitDepthParams)
139     GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A5B7, DefaultL7R5BitDepthParams)
140     default:
141       abort();
142 #undef GEMMLOWP_HANDLE_BIT_DEPTH_INT32
143   }
144 }
145 
// Grow-only byte buffer whose usable region starts on a 32-byte boundary.
//
// buffer() returns nullptr until AssureSize() has allocated something, and
// again after Clear(). Growing the buffer allocates a new array and releases
// the old one; existing contents are not copied over.
class Scratch {
 public:
  Scratch() : buffer_(), buffer_32_(nullptr), size_(0) {}

  // Makes sure at least required_size usable bytes are available from
  // buffer(). Does nothing when the current buffer is already large enough
  // (which includes any required_size <= 0).
  void AssureSize(std::int32_t required_size) {
    if (required_size <= size_) {
      return;
    }
    // Over-allocate by one alignment unit so an aligned start always fits.
    // The sum is formed in size_t: in 32-bit signed arithmetic it would
    // overflow for sizes within kAlignment of the int32 maximum.
    const std::size_t padded_size =
        static_cast<std::size_t>(required_size) + kAlignment;
    buffer_.reset(new std::uint8_t[padded_size]);

    // Advance to the first kAlignment-aligned address inside the allocation.
    const std::uintptr_t misalignment =
        reinterpret_cast<std::uintptr_t>(buffer_.get()) % kAlignment;
    buffer_32_ = buffer_.get() +
                 (misalignment == 0 ? 0 : kAlignment - misalignment);
    assert((reinterpret_cast<std::uintptr_t>(buffer_32_) % kAlignment) == 0);
    size_ = required_size;
  }

  // Releases the allocation; buffer() returns nullptr afterwards.
  void Clear() {
    buffer_.reset(nullptr);
    buffer_32_ = nullptr;
    size_ = 0;
  }

  // Start of the aligned usable region, or nullptr when nothing is allocated.
  std::uint8_t* buffer() { return buffer_32_; }

 private:
  // Alignment, in bytes, of the pointer returned by buffer().
  static const std::size_t kAlignment = 32;

  std::unique_ptr<std::uint8_t[]> buffer_;  // Owns the raw allocation.
  std::uint8_t* buffer_32_;                 // Aligned pointer into buffer_.
  std::int32_t size_;                       // Usable bytes at buffer_32_.
};
175 
176 Scratch* global_scratch = nullptr;
177 
GetOrCreateGlobalScratch()178 Scratch* GetOrCreateGlobalScratch() {
179   if (global_scratch == nullptr) {
180     global_scratch = new Scratch();
181   }
182   return global_scratch;
183 }
184 
DestroyGlobalScratch()185 void DestroyGlobalScratch() {
186   delete global_scratch;
187   global_scratch = nullptr;
188 }
189 
190 #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
191 
// Returns true when a rows x cols matrix can be treated as row major: either
// it has a single row (a vector is both row and column major), or it is
// flagged as transposed and tightly packed, i.e. its stride equals cols.
bool IsRowMajorOrVector(bool transpose, int stride, int rows, int cols) {
  const bool single_row = (rows == 1);
  const bool packed_row_major = transpose && (stride == cols);
  return packed_row_major || single_row;
}
205 
// Returns true when a rows x cols matrix can be treated as column major:
// either it has a single column (a vector is both row and column major), or
// it is not flagged as transposed and is tightly packed, i.e. its stride
// equals rows.
bool IsColumnMajorOrVector(bool transpose, int stride, int rows, int cols) {
  const bool single_column = (cols == 1);
  const bool packed_column_major = !transpose && (stride == rows);
  return packed_column_major || single_column;
}
219 
CanHandleMetaFastpath(bool transpose_a,bool transpose_b,bool transpose_c,int m,int n,int k,int lda,int ldb,int ldc,BitDepthSetting depth_setting)220 bool CanHandleMetaFastpath(bool transpose_a, bool transpose_b, bool transpose_c,
221                            int m, int n, int k, int lda, int ldb, int ldc,
222                            BitDepthSetting depth_setting) {
223   // Meta fastpath only supports 8bit x 8bit and k between 8 and 2048.
224   if (depth_setting != BitDepthSetting::A8B8 || k < 8 || k > 2048) {
225     return false;
226   }
227 
228   // The first operand needs to be a row major matrix or a vector.
229   if (!IsRowMajorOrVector(transpose_a, lda, m, k)) {
230     return false;
231   }
232 
233   // The second operand needs to be a column major matrix or a vector.
234   if (!IsColumnMajorOrVector(transpose_b, ldb, k, n)) {
235     return false;
236   }
237 
238   // The result can either be a row major matrix, a column major matrix or
239   // a vector.
240   if (IsRowMajorOrVector(transpose_c, ldc, m, n)) {
241     return true;
242   }
243 
244   if (IsColumnMajorOrVector(transpose_c, ldc, m, n)) {
245     return true;
246   }
247 
248   return false;
249 }
250 
251 // Assure enough scratch memory is allocated and run the fast path gemm.
MetaGemmQuantized8Bit(GemmContext * context,const std::uint8_t * lhs,const std::uint8_t * rhs,int m,int n,int k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t sum_offset,std::int32_t multiplicative_offset,std::int32_t shift,bool result_transpose,std::int32_t result_stride,std::uint8_t * result)252 void MetaGemmQuantized8Bit(GemmContext* context, const std::uint8_t* lhs,
253                            const std::uint8_t* rhs, int m, int n, int k,
254                            std::int32_t lhs_offset, std::int32_t rhs_offset,
255                            std::int32_t sum_offset,
256                            std::int32_t multiplicative_offset,
257                            std::int32_t shift, bool result_transpose,
258                            std::int32_t result_stride, std::uint8_t* result) {
259   Scratch* scratch = GetOrCreateGlobalScratch();
260   const std::int32_t max_num_threads = context->max_num_threads();
261   if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
262     scratch->AssureSize(meta::gemm_q8_scratch(m, n, k, max_num_threads));
263     meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads,
264                                scratch->buffer(), lhs, rhs, m, n, k, lhs_offset,
265                                rhs_offset, sum_offset, multiplicative_offset,
266                                shift, result);
267   } else {
268     scratch->AssureSize(meta::gemm_q8_scratch(n, m, k, max_num_threads));
269     meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads,
270                                scratch->buffer(), rhs, lhs, n, m, k, rhs_offset,
271                                lhs_offset, sum_offset, multiplicative_offset,
272                                shift, result);
273   }
274 }
275 
276 // Assure enough scratch memory is allocated and run the 8bit to float fast
277 // path gemm.
MetaGemmFloat(GemmContext * context,const std::uint8_t * lhs,const std::uint8_t * rhs,int m,int n,int k,std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset,bool result_transpose,std::int32_t result_stride,float * result)278 void MetaGemmFloat(GemmContext* context, const std::uint8_t* lhs,
279                    const std::uint8_t* rhs, int m, int n, int k,
280                    std::int32_t lhs_offset, std::int32_t rhs_offset,
281                    float result_offset, bool result_transpose,
282                    std::int32_t result_stride, float* result) {
283   Scratch* scratch = GetOrCreateGlobalScratch();
284   const std::int32_t max_num_threads = context->max_num_threads();
285   if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
286     scratch->AssureSize(meta::gemm_f_scratch(m, n, k, max_num_threads));
287     meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads,
288                               scratch->buffer(), lhs, rhs, m, n, k, lhs_offset,
289                               rhs_offset, result_offset, result);
290   } else {
291     scratch->AssureSize(meta::gemm_f_scratch(n, m, k, max_num_threads));
292     meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads,
293                               scratch->buffer(), rhs, lhs, n, m, k, rhs_offset,
294                               lhs_offset, result_offset, result);
295   }
296 }
297 
298 #endif
299 
300 }  // end anonymous namespace
301 
302 // Public interface entry points
303 
EightBitIntGemm(bool transpose_a,bool transpose_b,bool transpose_c,int m,int n,int k,const std::uint8_t * a,std::int32_t a_offset,int lda,const std::uint8_t * b,std::int32_t b_offset,int ldb,std::uint8_t * c,std::int32_t c_offset,std::int32_t c_mult_int,std::int32_t c_shift,int ldc,BitDepthSetting bit_depth)304 void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
305                      int m, int n, int k, const std::uint8_t* a,
306                      std::int32_t a_offset, int lda, const std::uint8_t* b,
307                      std::int32_t b_offset, int ldb, std::uint8_t* c,
308                      std::int32_t c_offset, std::int32_t c_mult_int,
309                      std::int32_t c_shift, int ldc, BitDepthSetting bit_depth) {
310   ScopedLock sl(GlobalMutexes::EightBitIntGemm());
311   GemmContext* context = GetOrCreateGlobalContext();
312 
313 #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
314   if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
315                             ldb, ldc, bit_depth)) {
316     MetaGemmQuantized8Bit(context, a, b, m, n, k, a_offset, b_offset, c_offset,
317                           c_mult_int, c_shift, transpose_c, ldc, c);
318     return;
319   }
320 #endif
321 
322 #define GEMMLOWP_HANDLE_CASE(ta, tb, tc)                                    \
323   if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {        \
324     EightBitIntGemmImpl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, b,  \
325                                     b_offset, ldb, c, c_offset, c_mult_int, \
326                                     c_shift, ldc, bit_depth);               \
327   }
328 
329   GEMMLOWP_HANDLE_CASE(false, false, false)
330   GEMMLOWP_HANDLE_CASE(false, false, true)
331   GEMMLOWP_HANDLE_CASE(false, true, false)
332   GEMMLOWP_HANDLE_CASE(false, true, true)
333   GEMMLOWP_HANDLE_CASE(true, false, false)
334   GEMMLOWP_HANDLE_CASE(true, false, true)
335   GEMMLOWP_HANDLE_CASE(true, true, false)
336   GEMMLOWP_HANDLE_CASE(true, true, true)
337 
338 #undef GEMMLOWP_HANDLE_CASE
339 }
340 
EightBitIntGemm(bool transpose_a,bool transpose_b,bool transpose_c,int m,int n,int k,const std::uint8_t * a,std::int32_t a_offset,std::int32_t lda,const std::uint8_t * b,std::int32_t b_offset,std::int32_t ldb,float * c,float c_offset,std::int32_t ldc,BitDepthSetting bit_depth)341 void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
342                      int m, int n, int k, const std::uint8_t* a,
343                      std::int32_t a_offset, std::int32_t lda,
344                      const std::uint8_t* b, std::int32_t b_offset,
345                      std::int32_t ldb, float* c, float c_offset,
346                      std::int32_t ldc, BitDepthSetting bit_depth) {
347   ScopedLock sl(GlobalMutexes::EightBitIntGemm());
348   GemmContext* context = GetOrCreateGlobalContext();
349 
350 #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
351   if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
352                             ldb, ldc, bit_depth)) {
353     MetaGemmFloat(context, a, b, m, n, k, a_offset, b_offset, c_offset,
354                   transpose_c, ldc, c);
355     return;
356   }
357 #endif
358 
359   // TODO(maciekc): implement a float output stage, get rid of scratch memory.
360   Scratch* scratch = GetOrCreateGlobalScratch();
361   if (transpose_c) {
362     scratch->AssureSize(m * ldc * sizeof(std::int32_t));
363   } else {
364     scratch->AssureSize(n * ldc * sizeof(std::int32_t));
365   }
366   std::int32_t* temp_c = reinterpret_cast<std::int32_t*>(scratch->buffer());
367 
368 #define GEMMLOWP_HANDLE_INT32_CASE(ta, tb, tc)                               \
369   if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {         \
370     EightBitIntGemmInt32Impl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, \
371                                          b, b_offset, ldb, temp_c, ldc,      \
372                                          bit_depth);                         \
373   }
374 
375   GEMMLOWP_HANDLE_INT32_CASE(false, false, false)
376   GEMMLOWP_HANDLE_INT32_CASE(false, false, true)
377   GEMMLOWP_HANDLE_INT32_CASE(false, true, false)
378   GEMMLOWP_HANDLE_INT32_CASE(false, true, true)
379   GEMMLOWP_HANDLE_INT32_CASE(true, false, false)
380   GEMMLOWP_HANDLE_INT32_CASE(true, false, true)
381   GEMMLOWP_HANDLE_INT32_CASE(true, true, false)
382   GEMMLOWP_HANDLE_INT32_CASE(true, true, true)
383 
384 #undef GEMMLOWP_HANDLE_INT32_CASE
385 
386   if (transpose_c) {
387     // Row major.
388     for (int i = 0; i < m; ++i) {
389       float* dest_row = c + i * ldc;
390       std::int32_t* src_row = temp_c + i * ldc;
391       for (int j = 0; j < n; ++j) {
392         dest_row[j] = static_cast<float>(src_row[j]) * c_offset;
393       }
394     }
395   } else {
396     // Column major.
397     for (int i = 0; i < n; ++i) {
398       float* dest_column = c + i * ldc;
399       std::int32_t* src_column = temp_c + i * ldc;
400       for (int j = 0; j < m; ++j) {
401         dest_column[j] = static_cast<float>(src_column[j]) * c_offset;
402       }
403     }
404   }
405 }
406 
SetMaxNumThreads(int n)407 void SetMaxNumThreads(int n) {
408   ScopedLock sl(GlobalMutexes::EightBitIntGemm());
409   GemmContext* context = GetOrCreateGlobalContext();
410   context->set_max_num_threads(n);
411 }
412 
FreePersistentResources()413 void FreePersistentResources() {
414   ScopedLock sl(GlobalMutexes::EightBitIntGemm());
415   DestroyGlobalContext();
416   DestroyGlobalScratch();
417 }
418 
419 }  // namespace eight_bit_int_gemm
420 }  // namespace gemmlowp
421