1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // multi_thread_gemm.h: Entry point to the multithreaded version of the
16 // generated (meta) gemm library.
17 
18 #ifndef GEMMLOWP_META_MULTI_THREAD_GEMM_H_
19 #define GEMMLOWP_META_MULTI_THREAD_GEMM_H_
20 
21 #ifdef GEMMLOWP_NEON_32
22 
#include <algorithm>
#include <cstdint>

#include "multi_thread_common.h"
#include "single_thread_gemm.h"
25 
26 namespace gemmlowp {
27 namespace meta {
28 namespace internal {
29 
// Budget, in bytes, for the slice of the RHS handed to a single
// ExecuteCacheFriendlyMatrixMatrix call. Larger RHS inputs are split into
// chunks sized against this budget.
const std::int32_t kMaxCacheFriendlySize = 24 * 1024;

// Runs `operation` over (lhs, rhs), splitting the RHS into chunks so that each
// chunk handed to operation.ExecuteCacheFriendlyMatrixMatrix stays close to
// kMaxCacheFriendlySize bytes.
//
// Template parameters:
//   IN_TYPE  - element type of lhs / rhs.
//   OUT_TYPE - element type of result.
//   F        - must expose a const member
//              ExecuteCacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k,
//                                               result, result_stride).
// Parameters:
//   scratch, lhs, m, k, result_stride - forwarded unchanged to every call.
//   rhs    - chunk i starts at rhs + i * (chunk_n * k) elements, i.e. the RHS
//            is consumed k elements per unit of n.
//   n      - split across calls; the per-call n values sum to n.
//   result - chunk i starts at result + i * chunk_n elements.
// Assumes non-negative m, n, k (TODO confirm callers guarantee this).
// If the whole RHS fits the budget, exactly one call is made with all inputs
// unchanged. Otherwise all chunks but the last use the same width; the last
// one absorbs the remainder (so it may be up to ~2x the regular width).
template <typename IN_TYPE, typename OUT_TYPE, typename F>
void CacheFriendlyMatrixMatrix(std::uint8_t* scratch, const IN_TYPE* lhs,
                               const IN_TYPE* rhs, std::int32_t m,
                               std::int32_t n, std::int32_t k, OUT_TYPE* result,
                               std::int32_t result_stride, const F& operation) {
  // 64-bit arithmetic: n * k * sizeof(IN_TYPE) must not overflow or be
  // narrowed back to 32 bits before the comparison.
  const std::int64_t bytes_per_column =
      static_cast<std::int64_t>(k) * static_cast<std::int64_t>(sizeof(IN_TYPE));
  const std::int64_t rhs_size = bytes_per_column * n;

  if (rhs_size <= kMaxCacheFriendlySize) {
    operation.ExecuteCacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k,
                                               result, result_stride);
    return;
  }

  // Widest chunk (rounded down to a multiple of 3, at least 1) whose RHS bytes
  // fit the budget. Sizing in bytes, consistently with rhs_size, guarantees
  // optimal_n <= n on this path. So chunks_count_less_one >= 0 for every
  // IN_TYPE, and no negative pointer offset can be formed below.
  const std::int32_t optimal_n =
      static_cast<std::int32_t>(std::max<std::int64_t>(
          1, 3 * (kMaxCacheFriendlySize / (bytes_per_column * 3))));
  const std::int32_t chunks_count_less_one = n / optimal_n - 1;
  const std::int32_t chunk_size = optimal_n * k;  // RHS elements per chunk.

  for (std::int32_t i = 0; i < chunks_count_less_one; ++i) {
    operation.ExecuteCacheFriendlyMatrixMatrix(
        scratch, lhs, rhs + i * chunk_size, m, optimal_n, k,
        result + i * optimal_n, result_stride);
  }

  // Final chunk takes everything left over (between optimal_n and
  // 2 * optimal_n - 1 units of n).
  const std::int32_t n_left = n - chunks_count_less_one * optimal_n;
  operation.ExecuteCacheFriendlyMatrixMatrix(
      scratch, lhs, rhs + chunks_count_less_one * chunk_size, m, n_left, k,
      result + chunks_count_less_one * optimal_n, result_stride);
}
57 
58 class GemmQuantized8BitOperation {
59  public:
GemmQuantized8BitOperation(std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t sum_offset,std::int32_t multiplier,std::int32_t shift)60   GemmQuantized8BitOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
61                              std::int32_t sum_offset, std::int32_t multiplier,
62                              std::int32_t shift)
63       : lhs_offset(lhs_offset),
64         rhs_offset(rhs_offset),
65         sum_offset(sum_offset),
66         multiplier(multiplier),
67         shift(shift) {}
68 
ExecuteMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::uint8_t * result,std::int32_t result_stride)69   void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
70                            const std::uint8_t* rhs, std::int32_t m,
71                            std::int32_t n, std::int32_t k, std::uint8_t* result,
72                            std::int32_t result_stride) const {
73     CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result, result_stride,
74                               *this);
75   }
76 
ExecuteCacheFriendlyMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::uint8_t * result,std::int32_t result_stride)77   void ExecuteCacheFriendlyMatrixMatrix(std::uint8_t* scratch,
78                                         const std::uint8_t* lhs,
79                                         const std::uint8_t* rhs, std::int32_t m,
80                                         std::int32_t n, std::int32_t k,
81                                         std::uint8_t* result,
82                                         std::int32_t result_stride) const {
83     gemm_q8_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
84                     sum_offset, multiplier, shift, result, result_stride);
85   }
86 
ScratchPerThread(std::int32_t m,std::int32_t n,std::int32_t k)87   static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
88                                        std::int32_t k) {
89     return 128 * 1024;
90   }
91 
92  private:
93   std::int32_t lhs_offset;
94   std::int32_t rhs_offset;
95   std::int32_t sum_offset;
96   std::int32_t multiplier;
97   std::int32_t shift;
98 };
99 
100 class GemmFloatOperation {
101  public:
GemmFloatOperation(std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset)102   GemmFloatOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
103                      float result_offset)
104       : lhs_offset(lhs_offset),
105         rhs_offset(rhs_offset),
106         result_offset(result_offset) {}
107 
ExecuteMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,float * result,std::int32_t result_stride)108   void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
109                            const std::uint8_t* rhs, std::int32_t m,
110                            std::int32_t n, std::int32_t k, float* result,
111                            std::int32_t result_stride) const {
112     CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result, result_stride,
113                               *this);
114   }
115 
ExecuteCacheFriendlyMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,float * result,std::int32_t result_stride)116   void ExecuteCacheFriendlyMatrixMatrix(std::uint8_t* scratch,
117                                         const std::uint8_t* lhs,
118                                         const std::uint8_t* rhs, std::int32_t m,
119                                         std::int32_t n, std::int32_t k,
120                                         float* result,
121                                         std::int32_t result_stride) const {
122     gemm_f_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
123                    result_offset, result, result_stride);
124   }
125 
ScratchPerThread(std::int32_t m,std::int32_t n,std::int32_t k)126   static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
127                                        std::int32_t k) {
128     return 128 * 1024;
129   }
130 
131  private:
132   std::int32_t lhs_offset;
133   std::int32_t rhs_offset;
134   float result_offset;
135 };
136 
137 class GemmInt32Operation {
138  public:
GemmInt32Operation(std::int32_t lhs_offset,std::int32_t rhs_offset)139   GemmInt32Operation(std::int32_t lhs_offset, std::int32_t rhs_offset)
140       : lhs_offset(lhs_offset), rhs_offset(rhs_offset) {}
141 
ExecuteMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t * result,std::int32_t result_stride)142   void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
143                            const std::uint8_t* rhs, std::int32_t m,
144                            std::int32_t n, std::int32_t k, std::int32_t* result,
145                            std::int32_t result_stride) const {
146     CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result, result_stride,
147                               *this);
148   }
149 
ExecuteCacheFriendlyMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t * result,std::int32_t result_stride)150   void ExecuteCacheFriendlyMatrixMatrix(std::uint8_t* scratch,
151                                         const std::uint8_t* lhs,
152                                         const std::uint8_t* rhs, std::int32_t m,
153                                         std::int32_t n, std::int32_t k,
154                                         std::int32_t* result,
155                                         std::int32_t result_stride) const {
156     gemm_i32_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset, result,
157                      result_stride);
158   }
159 
ScratchPerThread(std::int32_t m,std::int32_t n,std::int32_t k)160   static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
161                                        std::int32_t k) {
162     return 128 * 1024;
163   }
164 
165  private:
166   std::int32_t lhs_offset;
167   std::int32_t rhs_offset;
168 };
169 
170 }  // namespace internal
171 
gemm_q8_scratch(std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t max_threads)172 std::int32_t gemm_q8_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
173                              std::int32_t max_threads) {
174   return internal::ResolveMaxThreads(max_threads) *
175          internal::GemmQuantized8BitOperation::ScratchPerThread(m, n, k);
176 }
177 
multi_thread_gemm_q8(gemmlowp::WorkersPool * pool,std::int32_t max_threads,std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t sum_offset,std::int32_t multiplier,std::int32_t shift,std::uint8_t * result)178 void multi_thread_gemm_q8(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
179                           std::uint8_t* scratch, const std::uint8_t* lhs,
180                           const std::uint8_t* rhs, std::int32_t m,
181                           std::int32_t n, std::int32_t k,
182                           std::int32_t lhs_offset, std::int32_t rhs_offset,
183                           std::int32_t sum_offset, std::int32_t multiplier,
184                           std::int32_t shift, std::uint8_t* result) {
185   internal::GemmQuantized8BitOperation operation(lhs_offset, rhs_offset,
186                                                  sum_offset, multiplier, shift);
187   internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, m,
188                                       n, k, result, n, operation);
189 }
190 
gemm_f_scratch(std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t max_threads)191 std::int32_t gemm_f_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
192                             std::int32_t max_threads) {
193   return internal::ResolveMaxThreads(max_threads) *
194          internal::GemmFloatOperation::ScratchPerThread(m, n, k);
195 }
196 
multi_thread_gemm_f(gemmlowp::WorkersPool * pool,std::int32_t max_threads,std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset,float * result)197 void multi_thread_gemm_f(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
198                          std::uint8_t* scratch, const std::uint8_t* lhs,
199                          const std::uint8_t* rhs, std::int32_t m,
200                          std::int32_t n, std::int32_t k,
201                          std::int32_t lhs_offset, std::int32_t rhs_offset,
202                          float result_offset, float* result) {
203   internal::GemmFloatOperation operation(lhs_offset, rhs_offset, result_offset);
204   internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, m,
205                                       n, k, result, n, operation);
206 }
207 
gemm_i32_scratch(std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t max_threads)208 std::int32_t gemm_i32_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
209                               std::int32_t max_threads) {
210   return internal::ResolveMaxThreads(max_threads) *
211          internal::GemmInt32Operation::ScratchPerThread(m, n, k);
212 }
213 
multi_thread_gemm_i32(gemmlowp::WorkersPool * pool,std::int32_t max_threads,std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t * result)214 void multi_thread_gemm_i32(gemmlowp::WorkersPool* pool,
215                            std::int32_t max_threads, std::uint8_t* scratch,
216                            const std::uint8_t* lhs, const std::uint8_t* rhs,
217                            std::int32_t m, std::int32_t n, std::int32_t k,
218                            std::int32_t lhs_offset, std::int32_t rhs_offset,
219                            std::int32_t* result) {
220   internal::GemmInt32Operation operation(lhs_offset, rhs_offset);
221   internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, m,
222                                       n, k, result, n, operation);
223 }
224 
225 }  // namespace meta
226 }  // namespace gemmlowp
227 
228 #else
229 #warning "Meta gemm fast-path requires GEMMLOWP_NEON_32!"
230 #endif
231 
232 #endif  // GEMMLOWP_META_MULTI_THREAD_GEMM_H_
233