1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef GEMMLOWP_META_LEGACY_SINGLE_THREAD_GEMM_H_
16 #define GEMMLOWP_META_LEGACY_SINGLE_THREAD_GEMM_H_
17 
18 #include "../internal/common.h"
19 
20 #ifdef GEMMLOWP_NEON
21 
22 #include "quantized_mul_kernels.h"
23 #include "single_thread_gemm.h"
24 #include "streams.h"
25 
26 namespace gemmlowp {
27 namespace meta {
28 
gemm_q8_strided(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t result_offset,std::int32_t multiplicative_offset,std::int32_t shift,std::uint8_t * result,std::int32_t result_stride)29 void gemm_q8_strided(std::uint8_t* scratch, const std::uint8_t* lhs,
30                      const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
31                      std::int32_t k, std::int32_t lhs_offset,
32                      std::int32_t rhs_offset, std::int32_t result_offset,
33                      std::int32_t multiplicative_offset, std::int32_t shift,
34                      std::uint8_t* result, std::int32_t result_stride) {
35 #ifdef DEBUG
36 #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE
37   std::cout << "Legacy::GemmQ8." << std::endl;
38 #endif
39 #endif
40   typedef GemmParams<std::uint8_t, std::uint8_t, RowMajorWithSum,
41                      RowMajorWithSum, QuantizedStaticPreprocessed, RowMajor>
42       Params;
43   Params params;
44 
45   params.m = m;
46   params.n = n;
47   params.k = k;
48 
49   params.lhs = lhs;
50   params.rhs = rhs;
51   params.result = result;
52   params.scratch = scratch;
53 
54   params.left_stream.count = k;
55   params.left_stream.stride = k;
56   params.left_stream.multiplicative_sum_offset = rhs_offset;
57   params.left_stream.additive_sum_offset =
58       result_offset + k * lhs_offset * rhs_offset;
59 
60   params.right_stream.count = k;
61   params.right_stream.stride = k;
62   params.right_stream.multiplicative_sum_offset = lhs_offset;
63   params.right_stream.additive_sum_offset = 0;
64 
65   params.fused_kernel.kernel.multiplicative_offset = multiplicative_offset;
66   params.fused_kernel.kernel.rounding_offset = (1 << (shift - 1));
67   params.fused_kernel.kernel.shift = -shift;
68   params.fused_kernel.kernel.count = k;
69   params.fused_kernel.output_stream.stride = result_stride;
70 
71   Gemm<GemmExecutorPackRHS, Params, 2, 4, 8>(params);
72 }
73 
gemv_q8(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t result_offset,std::int32_t multiplicative_offset,std::int32_t shift,std::uint8_t * result)74 void gemv_q8(std::uint8_t* scratch, const std::uint8_t* lhs,
75              const std::uint8_t* rhs, std::int32_t n, std::int32_t k,
76              std::int32_t lhs_offset, std::int32_t rhs_offset,
77              std::int32_t result_offset, std::int32_t multiplicative_offset,
78              std::int32_t shift, std::uint8_t* result) {
79 #ifdef DEBUG
80 #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE
81   std::cout << "Legacy::GemvQ8." << std::endl;
82 #endif
83 #endif
84   typedef GemmParams<std::uint8_t, std::uint8_t, RowMajorWithSum,
85                      RowMajorWithSum, QuantizedStaticPreprocessed, RowMajor>
86       Params;
87   Params params;
88 
89   params.m = 1;
90   params.n = n;
91   params.k = k;
92 
93   params.lhs = lhs;
94   params.rhs = rhs;
95   params.result = result;
96   params.scratch = scratch;
97 
98   params.left_stream.count = k;
99   params.left_stream.stride = k;
100   params.left_stream.multiplicative_sum_offset = rhs_offset;
101   params.left_stream.additive_sum_offset =
102       result_offset + k * lhs_offset * rhs_offset;
103 
104   params.right_stream.count = k;
105   params.right_stream.stride = k;
106   params.right_stream.multiplicative_sum_offset = lhs_offset;
107   params.right_stream.additive_sum_offset = 0;
108 
109   params.fused_kernel.kernel.multiplicative_offset = multiplicative_offset;
110   params.fused_kernel.kernel.rounding_offset = (1 << (shift - 1));
111   params.fused_kernel.kernel.shift = -shift;
112   params.fused_kernel.kernel.count = k;
113   params.fused_kernel.output_stream.stride = n;
114 
115   if (k < 1536) {
116     Gemm<GemmExecutorPackLHS, Params, 1, 8, 8>(params);
117   } else {
118     Gemm<GemmExecutorPackLHS, Params, 2, 4, 8>(params);
119   }
120 }
121 
gemm_i32_strided(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t * result,std::int32_t result_stride)122 void gemm_i32_strided(std::uint8_t* scratch, const std::uint8_t* lhs,
123                       const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
124                       std::int32_t k, std::int32_t lhs_offset,
125                       std::int32_t rhs_offset, std::int32_t* result,
126                       std::int32_t result_stride) {
127 #ifdef DEBUG
128 #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE
129   std::cout << "Legacy::GemmI32." << std::endl;
130 #endif
131 #endif
132   typedef GemmParams<std::uint8_t, std::int32_t, RowMajorWithSum,
133                      RowMajorWithSum, QuantizedStaticPreprocessedAsInt32,
134                      RowMajor>
135       Params;
136   Params params;
137 
138   params.m = m;
139   params.n = n;
140   params.k = k;
141 
142   params.lhs = lhs;
143   params.rhs = rhs;
144   params.result = result;
145   params.scratch = scratch;
146 
147   params.left_stream.count = k;
148   params.left_stream.stride = k;
149   params.left_stream.multiplicative_sum_offset = rhs_offset;
150   params.left_stream.additive_sum_offset = k * lhs_offset * rhs_offset;
151 
152   params.right_stream.count = k;
153   params.right_stream.stride = k;
154   params.right_stream.multiplicative_sum_offset = lhs_offset;
155   params.right_stream.additive_sum_offset = 0;
156 
157   params.fused_kernel.kernel.count = k;
158   params.fused_kernel.output_stream.stride = result_stride * 4;
159 
160   Gemm<GemmExecutorPackRHS, Params, 2, 4, 8>(params);
161 }
162 
gemv_i32(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t * result)163 void gemv_i32(std::uint8_t* scratch, const std::uint8_t* lhs,
164               const std::uint8_t* rhs, std::int32_t n, std::int32_t k,
165               std::int32_t lhs_offset, std::int32_t rhs_offset,
166               std::int32_t* result) {
167 #ifdef DEBUG
168 #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE
169   std::cout << "Legacy::GemvI32." << std::endl;
170 #endif
171 #endif
172   typedef GemmParams<std::uint8_t, std::int32_t, RowMajorWithSum,
173                      RowMajorWithSum, QuantizedStaticPreprocessedAsInt32,
174                      RowMajor>
175       Params;
176   Params params;
177 
178   params.m = 1;
179   params.n = n;
180   params.k = k;
181 
182   params.lhs = lhs;
183   params.rhs = rhs;
184   params.result = result;
185   params.scratch = scratch;
186 
187   params.left_stream.count = k;
188   params.left_stream.stride = k;
189   params.left_stream.multiplicative_sum_offset = rhs_offset;
190   params.left_stream.additive_sum_offset = k * lhs_offset * rhs_offset;
191 
192   params.right_stream.count = k;
193   params.right_stream.stride = k;
194   params.right_stream.multiplicative_sum_offset = lhs_offset;
195   params.right_stream.additive_sum_offset = 0;
196 
197   params.fused_kernel.kernel.count = k;
198   params.fused_kernel.output_stream.stride = 0;
199 
200   if (k < 1664) {
201     Gemm<GemmExecutorPackLHS, Params, 1, 8, 8>(params);
202   } else {
203     Gemm<GemmExecutorPackLHS, Params, 1, 6, 8>(params);
204   }
205 }
206 
gemm_f_strided(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset,float * result,std::int32_t result_stride)207 void gemm_f_strided(std::uint8_t* scratch, const std::uint8_t* lhs,
208                     const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
209                     std::int32_t k, std::int32_t lhs_offset,
210                     std::int32_t rhs_offset, float result_offset, float* result,
211                     std::int32_t result_stride) {
212 #ifdef DEBUG
213 #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE
214   std::cout << "Legacy::GemmF." << std::endl;
215 #endif
216 #endif
217   typedef GemmParams<std::uint8_t, float, RowMajorWithSum, RowMajorWithSum,
218                      QuantizedStaticPreprocessedAsFloat, RowMajor>
219       Params;
220   Params params;
221 
222   params.m = m;
223   params.n = n;
224   params.k = k;
225 
226   params.lhs = lhs;
227   params.rhs = rhs;
228   params.result = result;
229   params.scratch = scratch;
230 
231   params.left_stream.count = k;
232   params.left_stream.stride = k;
233   params.left_stream.multiplicative_sum_offset = rhs_offset;
234   params.left_stream.additive_sum_offset = k * lhs_offset * rhs_offset;
235 
236   params.right_stream.count = k;
237   params.right_stream.stride = k;
238   params.right_stream.multiplicative_sum_offset = lhs_offset;
239   params.right_stream.additive_sum_offset = 0;
240 
241   params.fused_kernel.kernel.count = k;
242   params.fused_kernel.kernel.scale = result_offset;
243   params.fused_kernel.output_stream.stride = result_stride * 4;
244 
245   Gemm<GemmExecutorPackRHS, Params, 2, 4, 8>(params);
246 }
247 
gemv_f(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset,float * result)248 void gemv_f(std::uint8_t* scratch, const std::uint8_t* lhs,
249             const std::uint8_t* rhs, std::int32_t n, std::int32_t k,
250             std::int32_t lhs_offset, std::int32_t rhs_offset,
251             float result_offset, float* result) {
252 #ifdef DEBUG
253 #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE
254   std::cout << "Legacy::GemvF." << std::endl;
255 #endif
256 #endif
257   typedef GemmParams<std::uint8_t, float, RowMajorWithSum, RowMajorWithSum,
258                      QuantizedStaticPreprocessedAsFloat, RowMajor>
259       Params;
260   Params params;
261 
262   params.m = 1;
263   params.n = n;
264   params.k = k;
265 
266   params.lhs = lhs;
267   params.rhs = rhs;
268   params.result = result;
269   params.scratch = scratch;
270 
271   params.left_stream.count = k;
272   params.left_stream.stride = k;
273   params.left_stream.multiplicative_sum_offset = rhs_offset;
274   params.left_stream.additive_sum_offset = k * lhs_offset * rhs_offset;
275 
276   params.right_stream.count = k;
277   params.right_stream.stride = k;
278   params.right_stream.multiplicative_sum_offset = lhs_offset;
279   params.right_stream.additive_sum_offset = 0;
280 
281   params.fused_kernel.kernel.count = k;
282   params.fused_kernel.kernel.scale = result_offset;
283   params.fused_kernel.output_stream.stride = 0;
284 
285   if (k < 1664) {
286     Gemm<GemmExecutorPackLHS, Params, 1, 8, 8>(params);
287   } else {
288     Gemm<GemmExecutorPackLHS, Params, 1, 6, 8>(params);
289   }
290 }
291 
292 }  // namespace meta
293 }  // namespace gemmlowp
294 
295 #else
296 #warning "Meta gemm fast-path requires GEMMLOWP_NEON_(32|64)!"
297 #endif
298 
299 #endif  // GEMMLOWP_META_LEGACY_SINGLE_THREAD_GEMM_H_
300