1 // Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <unistd.h>
16 #ifdef __APPLE__
17 #include <sys/time.h>
18 #endif
19 
20 #include <cstdint>
21 #include <cstdlib>
22 #include <ctime>
23 #include <iomanip>
24 #include <iostream>
25 #include <map>
26 #include <memory>
27 #include <vector>
28 
29 #include "multi_thread_gemm.h"
30 #include "quantized_mul_kernels.h"
31 #include "single_thread_gemm.h"
32 #include "streams.h"
33 
// Quantization constants shared by the parameter setup helpers and the scalar
// reference checks in this file. They are typed constants rather than
// object-like macros so that they obey normal scoping and type checking.

// Added to every LHS byte before it is multiplied (see the check_* helpers).
constexpr int LHS_OFFSET = -127;
// Added to every RHS byte before it is multiplied.
constexpr int RHS_OFFSET = -127;
// Per-depth-element bias: the uint8 reference adds SUM_OFFSET * depth to the
// accumulated dot product.
constexpr int SUM_OFFSET = 127;
// Integer multiplier applied to the biased accumulator before rounding.
constexpr int MUL_OFFSET = 1;
// Right-shift applied after rounding; the rounding term is 1 << (SHIFT - 1).
constexpr int SHIFT = 7;
// Scale factor applied to the integer accumulator for float output.
constexpr float FLOAT_SCALE = 0.333f;
40 
41 using namespace gemmlowp::meta;
42 
43 // Input, output & kernel setups.
44 
45 typedef GemmParams<std::uint8_t, std::uint8_t, RowMajorWithSum, ColumnMajorWithSum,
46                    QuantizedStaticPreprocessed, RowMajor>
47     ParamsColumnMajor;
48 
49 typedef GemmParams<std::uint8_t, std::uint8_t, RowMajorWithSum, RowMajorWithSum,
50                    QuantizedStaticPreprocessed, RowMajor>
51     ParamsRowMajor;
52 
53 typedef GemmParams<std::uint8_t, float, RowMajorWithSum, ColumnMajorWithSum,
54                    QuantizedStaticPreprocessedAsFloat, RowMajor>
55     ParamsColumnMajorAsFloat;
56 
57 typedef GemmParams<std::uint8_t, float, RowMajorWithSum, RowMajorWithSum,
58                    QuantizedStaticPreprocessedAsFloat, RowMajor>
59     ParamsRowMajorAsFloat;
60 
61 typedef GemmParams<std::uint8_t, std::int32_t, RowMajorWithSum, ColumnMajorWithSum,
62                    QuantizedStaticPreprocessedAsInt32, RowMajor>
63     ParamsColumnMajorAsInt32;
64 
65 typedef GemmParams<std::uint8_t, std::int32_t, RowMajorWithSum, RowMajorWithSum,
66                    QuantizedStaticPreprocessedAsInt32, RowMajor>
67     ParamsRowMajorAsInt32;
68 
69 typedef gemmlowp::WorkersPool Pool;
70 typedef SimpleContext<gemmlowp::WorkersPool> Context;
71 
72 #ifdef LHS_PACK
73 typedef GemmExecutorPackLHSCacheFriendly<> Executor;
74 #else
75 typedef GemmExecutorPackRHSCacheFriendly<> Executor;
76 #endif
77 
78 // Testing helper functions.
79 
// Fills the first rows * cols bytes of `data` with a deterministic
// pseudo-random sequence: the first byte is `seed` and each following value
// is (previous * seed_2 + seed) % 256, truncated to uint8 when stored.
void prepare_test_data(std::uint8_t* data, std::int32_t rows, std::int32_t cols,
                       std::int32_t seed, std::int32_t seed_2) {
  const int total = rows * cols;
  std::int32_t current = seed;
  int position = 0;
  while (position < total) {
    data[position] = static_cast<std::uint8_t>(current);
    current = (current * seed_2 + seed) % 256;
    ++position;
  }
}
88 
// Zeroes the first rows * cols elements of `data`; elements past that count
// are left untouched.
template <typename CLEAR_TYPE>
void clear(int rows, int cols, CLEAR_TYPE* data) {
  const int count = rows * cols;
  int index = 0;
  while (index < count) {
    data[index] = 0;
    ++index;
  }
}
95 
check_row_row(std::uint8_t * lhs,std::uint8_t * rhs,std::uint8_t * results,int rows,int cols,int depth)96 bool check_row_row(std::uint8_t* lhs, std::uint8_t* rhs, std::uint8_t* results, int rows,
97                    int cols, int depth) {
98   int wrong = 0;
99   int rounding = (1 << (SHIFT - 1));
100   for (int i = 0; i < rows; ++i) {
101     for (int j = 0; j < cols; ++j) {
102       int expected = 0;
103       for (int k = 0; k < depth; ++k) {
104         expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) *
105                     (static_cast<int>(rhs[depth * j + k]) + RHS_OFFSET);
106       }
107       expected += SUM_OFFSET * depth;
108       expected *= MUL_OFFSET;
109       expected += rounding;
110       expected = (expected >> SHIFT);
111       if (expected < 0) {
112         expected = 0;
113       } else if (expected > 255) {
114         expected = 255;
115       }
116       expected = static_cast<int>(static_cast<std::uint8_t>(expected));
117       int actual = static_cast<int>(results[i * cols + j]);
118       if (actual != expected) {
119         std::cout << "Wrong @" << i << "x" << j << " : " << actual
120                   << " != " << expected << std::endl;
121         wrong++;
122       }
123     }
124   }
125   if (wrong != 0) {
126     std::cout << wrong << "/" << (rows * cols) << std::endl;
127   }
128   return wrong == 0;
129 }
130 
check_row_col(std::uint8_t * lhs,std::uint8_t * rhs,std::uint8_t * results,int rows,int cols,int depth)131 bool check_row_col(std::uint8_t* lhs, std::uint8_t* rhs, std::uint8_t* results, int rows,
132                    int cols, int depth) {
133   int wrong = 0;
134   int rounding = (1 << (SHIFT - 1));
135   for (int i = 0; i < rows; ++i) {
136     for (int j = 0; j < cols; ++j) {
137       int expected = 0;
138       for (int k = 0; k < depth; ++k) {
139         expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) *
140                     (static_cast<int>(rhs[j + k * cols]) + RHS_OFFSET);
141       }
142       expected += SUM_OFFSET * depth;
143       expected *= MUL_OFFSET;
144       expected += rounding;
145       expected = (expected >> SHIFT);
146       if (expected < 0) {
147         expected = 0;
148       } else if (expected > 255) {
149         expected = 255;
150       }
151       expected = static_cast<int>(static_cast<std::uint8_t>(expected));
152       int actual = static_cast<int>(results[i * cols + j]);
153       if (actual != expected) {
154         wrong++;
155       }
156     }
157   }
158   return wrong == 0;
159 }
160 
check_row_row_f(std::uint8_t * lhs,std::uint8_t * rhs,float * results,int rows,int cols,int depth)161 bool check_row_row_f(std::uint8_t* lhs, std::uint8_t* rhs, float* results, int rows,
162                      int cols, int depth) {
163   int wrong = 0;
164   for (int i = 0; i < rows; ++i) {
165     for (int j = 0; j < cols; ++j) {
166       int expected = 0;
167       for (int k = 0; k < depth; ++k) {
168         expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) *
169                     (static_cast<int>(rhs[depth * j + k]) + RHS_OFFSET);
170       }
171       float expected_float = static_cast<float>(expected) * FLOAT_SCALE;
172       float actual = results[i * cols + j];
173       if (actual != expected_float) {
174         wrong++;
175       }
176     }
177   }
178   return wrong == 0;
179 }
180 
check_row_col_f(std::uint8_t * lhs,std::uint8_t * rhs,float * results,int rows,int cols,int depth)181 bool check_row_col_f(std::uint8_t* lhs, std::uint8_t* rhs, float* results, int rows,
182                      int cols, int depth) {
183   int wrong = 0;
184   for (int i = 0; i < rows; ++i) {
185     for (int j = 0; j < cols; ++j) {
186       int expected = 0;
187       for (int k = 0; k < depth; ++k) {
188         expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) *
189                     (static_cast<int>(rhs[j + k * cols]) + RHS_OFFSET);
190       }
191       float expected_float = static_cast<float>(expected) * FLOAT_SCALE;
192       float actual = results[i * cols + j];
193       if (actual != expected_float) {
194         wrong++;
195       }
196     }
197   }
198   return wrong == 0;
199 }
200 
check_row_row_i32(std::uint8_t * lhs,std::uint8_t * rhs,std::int32_t * results,int rows,int cols,int depth)201 bool check_row_row_i32(std::uint8_t* lhs, std::uint8_t* rhs, std::int32_t* results, int rows,
202                        int cols, int depth) {
203   int wrong = 0;
204   for (int i = 0; i < rows; ++i) {
205     for (int j = 0; j < cols; ++j) {
206       int expected = 0;
207       for (int k = 0; k < depth; ++k) {
208         expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) *
209                     (static_cast<int>(rhs[depth * j + k]) + RHS_OFFSET);
210       }
211       int actual = results[i * cols + j];
212       if (actual != expected) {
213         wrong++;
214       }
215     }
216   }
217   return wrong == 0;
218 }
219 
check_row_col_i32(std::uint8_t * lhs,std::uint8_t * rhs,std::int32_t * results,int rows,int cols,int depth)220 bool check_row_col_i32(std::uint8_t* lhs, std::uint8_t* rhs, std::int32_t* results, int rows,
221                        int cols, int depth) {
222   int wrong = 0;
223   for (int i = 0; i < rows; ++i) {
224     for (int j = 0; j < cols; ++j) {
225       int expected = 0;
226       for (int k = 0; k < depth; ++k) {
227         expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) *
228                     (static_cast<int>(rhs[j + k * cols]) + RHS_OFFSET);
229       }
230       int actual = results[i * cols + j];
231       if (actual != expected) {
232         wrong++;
233       }
234     }
235   }
236   return wrong == 0;
237 }
238 
239 template <typename PARAMS, typename RESULT_TYPE>
setup_params(std::uint8_t * lhs,std::uint8_t * rhs,RESULT_TYPE * result,std::uint8_t * scratch,PARAMS * params)240 void setup_params(std::uint8_t* lhs, std::uint8_t* rhs, RESULT_TYPE* result,
241                   std::uint8_t* scratch, PARAMS* params) {
242   params->lhs = lhs;
243   params->rhs = rhs;
244   params->result = result;
245   params->scratch = scratch;
246 
247   params->left_stream.multiplicative_sum_offset = RHS_OFFSET;
248   params->left_stream.additive_sum_offset = 0;
249 
250   params->right_stream.multiplicative_sum_offset = LHS_OFFSET;
251   params->right_stream.additive_sum_offset = 0;
252 }
253 
setup_row_row(int m,int n,int k,ParamsRowMajor * params)254 void setup_row_row(int m, int n, int k, ParamsRowMajor* params) {
255   params->m = m;
256   params->n = n;
257   params->k = k;
258   params->left_stream.count = k;
259   params->left_stream.stride = k;
260   params->left_stream.additive_sum_offset =
261       SUM_OFFSET * k + k * LHS_OFFSET * RHS_OFFSET;
262   params->right_stream.count = k;
263   params->right_stream.stride = k;
264   params->fused_kernel.kernel.count = k;
265   params->fused_kernel.kernel.multiplicative_offset = MUL_OFFSET;
266   params->fused_kernel.kernel.rounding_offset = (1 << (SHIFT - 1));
267   params->fused_kernel.kernel.shift = -SHIFT;
268   params->fused_kernel.output_stream.stride = n;
269 }
270 
setup_row_col(int m,int n,int k,ParamsColumnMajor * params)271 void setup_row_col(int m, int n, int k, ParamsColumnMajor* params) {
272   params->m = m;
273   params->n = n;
274   params->k = k;
275   params->left_stream.count = k;
276   params->left_stream.stride = k;
277   params->left_stream.additive_sum_offset =
278       SUM_OFFSET * k + k * LHS_OFFSET * RHS_OFFSET;
279   params->right_stream.count = k;
280   params->right_stream.stride = n;
281   params->fused_kernel.kernel.count = k;
282   params->fused_kernel.kernel.multiplicative_offset = MUL_OFFSET;
283   params->fused_kernel.kernel.rounding_offset = (1 << (SHIFT - 1));
284   params->fused_kernel.kernel.shift = -SHIFT;
285   params->fused_kernel.output_stream.stride = n;
286 }
287 
setup_row_row_f(int m,int n,int k,ParamsRowMajorAsFloat * params)288 void setup_row_row_f(int m, int n, int k, ParamsRowMajorAsFloat* params) {
289   params->m = m;
290   params->n = n;
291   params->k = k;
292   params->left_stream.count = k;
293   params->left_stream.stride = k;
294   params->left_stream.additive_sum_offset = k * LHS_OFFSET * RHS_OFFSET;
295   params->right_stream.count = k;
296   params->right_stream.stride = k;
297   params->fused_kernel.kernel.count = k;
298   params->fused_kernel.kernel.scale = FLOAT_SCALE;
299   params->fused_kernel.output_stream.stride = n * sizeof(float);
300 }
301 
setup_row_col_f(int m,int n,int k,ParamsColumnMajorAsFloat * params)302 void setup_row_col_f(int m, int n, int k, ParamsColumnMajorAsFloat* params) {
303   params->m = m;
304   params->n = n;
305   params->k = k;
306   params->left_stream.count = k;
307   params->left_stream.stride = k;
308   params->left_stream.additive_sum_offset = k * LHS_OFFSET * RHS_OFFSET;
309   params->right_stream.count = k;
310   params->right_stream.stride = n;
311   params->fused_kernel.kernel.count = k;
312   params->fused_kernel.kernel.scale = FLOAT_SCALE;
313   params->fused_kernel.output_stream.stride = n * sizeof(float);
314 }
315 
setup_row_row_i32(int m,int n,int k,ParamsRowMajorAsInt32 * params)316 void setup_row_row_i32(int m, int n, int k, ParamsRowMajorAsInt32* params) {
317   params->m = m;
318   params->n = n;
319   params->k = k;
320   params->left_stream.count = k;
321   params->left_stream.stride = k;
322   params->left_stream.additive_sum_offset = k * LHS_OFFSET * RHS_OFFSET;
323   params->right_stream.count = k;
324   params->right_stream.stride = k;
325   params->fused_kernel.kernel.count = k;
326   params->fused_kernel.output_stream.stride = n * sizeof(std::int32_t);
327 }
328 
setup_row_col_i32(int m,int n,int k,ParamsColumnMajorAsInt32 * params)329 void setup_row_col_i32(int m, int n, int k, ParamsColumnMajorAsInt32* params) {
330   params->m = m;
331   params->n = n;
332   params->k = k;
333   params->left_stream.count = k;
334   params->left_stream.stride = k;
335   params->left_stream.additive_sum_offset = k * LHS_OFFSET * RHS_OFFSET;
336   params->right_stream.count = k;
337   params->right_stream.stride = n;
338   params->fused_kernel.kernel.count = k;
339   params->fused_kernel.output_stream.stride = n * sizeof(std::int32_t);
340 }
341 
main()342 int main() {
343   ParamsRowMajor params_row;
344   ParamsColumnMajor params_col;
345   ParamsRowMajorAsFloat params_row_f;
346   ParamsColumnMajorAsFloat params_col_f;
347   ParamsRowMajorAsInt32 params_row_i32;
348   ParamsColumnMajorAsInt32 params_col_i32;
349 
350   std::unique_ptr<std::uint8_t> lhs(new std::uint8_t[1024 * 1024]);
351   std::unique_ptr<std::uint8_t> rhs(new std::uint8_t[1024 * 1024]);
352   std::unique_ptr<std::uint8_t> result(new std::uint8_t[1024 * 1024]);
353   std::unique_ptr<float> result_f(new float[1024 * 1024]);
354   std::unique_ptr<std::int32_t> result_i32(new std::int32_t[1024 * 1024]);
355   std::unique_ptr<std::uint8_t> scratch(new std::uint8_t[4048 * 1024]);
356 
357   setup_params(lhs.get(), rhs.get(), result.get(), scratch.get(), &params_row);
358   setup_params(lhs.get(), rhs.get(), result.get(), scratch.get(), &params_col);
359   setup_params(lhs.get(), rhs.get(), result_f.get(), scratch.get(),
360                &params_row_f);
361   setup_params(lhs.get(), rhs.get(), result_f.get(), scratch.get(),
362                &params_col_f);
363   setup_params(lhs.get(), rhs.get(), result_i32.get(), scratch.get(),
364                &params_row_i32);
365   setup_params(lhs.get(), rhs.get(), result_i32.get(), scratch.get(),
366                &params_col_i32);
367 
368   Pool pool;
369   Context context(4, &pool);
370 
371   for (int i = 1; i < 16; ++i) {
372     for (int j = 1; j < 16; ++j) {
373       for (int k = 1; k < 24; ++k) {
374         prepare_test_data(lhs.get(), i, k, 11, 13);
375         prepare_test_data(rhs.get(), j, k, 13, 17);
376 
377         clear(i, j, result.get());
378         setup_row_row(i, j, k, &params_row);
379         Gemm<Executor, ParamsRowMajor, 2, 4, 8>(params_row);
380         if (!check_row_row(lhs.get(), rhs.get(), result.get(), i, j, k)) {
381           std::cout << "Row: " << i << "x" << j << "x" << k << " : ERROR"
382                     << std::endl;
383           std::cout << "Exiting." << std::endl;
384           std::exit(1);
385         }
386 
387         clear(i, j, result.get());
388         setup_row_col(i, j, k, &params_col);
389         Gemm<Executor, ParamsColumnMajor, 2, 4, 8>(params_col);
390         if (!check_row_col(lhs.get(), rhs.get(), result.get(), i, j, k)) {
391           std::cout << "Column: " << i << "x" << j << "x" << k << " : ERROR"
392                     << std::endl;
393           std::cout << "Exiting." << std::endl;
394           std::exit(1);
395         }
396 
397         clear(i, j, result_f.get());
398         setup_row_row_f(i, j, k, &params_row_f);
399         Gemm<Executor, ParamsRowMajorAsFloat, 2, 4, 8>(params_row_f);
400         if (!check_row_row_f(lhs.get(), rhs.get(), result_f.get(), i, j, k)) {
401           std::cout << "RowAsFloat: " << i << "x" << j << "x" << k << " : ERROR"
402                     << std::endl;
403           std::cout << "Exiting." << std::endl;
404           std::exit(1);
405         }
406 
407         clear(i, j, result_f.get());
408         setup_row_col_f(i, j, k, &params_col_f);
409         Gemm<Executor, ParamsColumnMajorAsFloat, 2, 4, 8>(params_col_f);
410         if (!check_row_col_f(lhs.get(), rhs.get(), result_f.get(), i, j, k)) {
411           std::cout << "ColumnAsFloat: " << i << "x" << j << "x" << k
412                     << " : ERROR" << std::endl;
413           std::cout << "Exiting." << std::endl;
414           std::exit(1);
415         }
416 
417         clear(i, j, result_i32.get());
418         setup_row_row_i32(i, j, k, &params_row_i32);
419         Gemm<Executor, ParamsRowMajorAsInt32, 2, 4, 8>(params_row_i32);
420         if (!check_row_row_i32(lhs.get(), rhs.get(), result_i32.get(), i, j,
421                                k)) {
422           std::cout << "RowAsInt32: " << i << "x" << j << "x" << k << " : ERROR"
423                     << std::endl;
424           std::cout << "Exiting." << std::endl;
425           std::exit(1);
426         }
427 
428         clear(i, j, result_i32.get());
429         setup_row_col_i32(i, j, k, &params_col_i32);
430         Gemm<Executor, ParamsColumnMajorAsInt32, 2, 4, 8>(params_col_i32);
431         if (!check_row_col_i32(lhs.get(), rhs.get(), result_i32.get(), i, j,
432                                k)) {
433           std::cout << "ColumnAsInt32: " << i << "x" << j << "x" << k
434                     << " : ERROR" << std::endl;
435           std::cout << "Exiting." << std::endl;
436           std::exit(1);
437         }
438       }
439     }
440   }
441 
442   for (int i = 1; i < 1024; i += 211) {
443     for (int j = 1; j < 1024; j += 211) {
444       for (int k = 8; k < 1024; k += 111) {
445         prepare_test_data(lhs.get(), i, k, 11, 13);
446         prepare_test_data(rhs.get(), j, k, 13, 17);
447 
448         clear(i, j, result.get());
449         setup_row_row(i, j, k, &params_row);
450         MultiThreadGemm<Context, Executor, ParamsRowMajor, 2, 4, 8>(&context,
451                                                                     params_row);
452         if (!check_row_row(lhs.get(), rhs.get(), result.get(), i, j, k)) {
453           std::cout << "Row(MT): " << i << "x" << j << "x" << k << " : ERROR"
454                     << std::endl;
455           std::cout << "Exiting." << std::endl;
456           std::exit(1);
457         }
458 
459         clear(i, j, result.get());
460         setup_row_col(i, j, k, &params_col);
461         MultiThreadGemm<Context, Executor, ParamsColumnMajor, 2, 4, 8>(
462             &context, params_col);
463         if (!check_row_col(lhs.get(), rhs.get(), result.get(), i, j, k)) {
464           std::cout << "Column(MT): " << i << "x" << j << "x" << k << " : ERROR"
465                     << std::endl;
466           std::cout << "Exiting." << std::endl;
467           std::exit(1);
468         }
469 
470         clear(i, j, result_f.get());
471         setup_row_row_f(i, j, k, &params_row_f);
472         MultiThreadGemm<Context, Executor, ParamsRowMajorAsFloat, 2, 4, 8>(
473             &context, params_row_f);
474         if (!check_row_row_f(lhs.get(), rhs.get(), result_f.get(), i, j, k)) {
475           std::cout << "RowAsFloat(MT): " << i << "x" << j << "x" << k
476                     << " : ERROR" << std::endl;
477           std::cout << "Exiting." << std::endl;
478           std::exit(1);
479         }
480 
481         clear(i, j, result_f.get());
482         setup_row_col_f(i, j, k, &params_col_f);
483         MultiThreadGemm<Context, Executor, ParamsColumnMajorAsFloat, 2, 4, 8>(
484             &context, params_col_f);
485         if (!check_row_col_f(lhs.get(), rhs.get(), result_f.get(), i, j, k)) {
486           std::cout << "ColumnAsFloat(MT): " << i << "x" << j << "x" << k
487                     << " : ERROR" << std::endl;
488           std::cout << "Exiting." << std::endl;
489           std::exit(1);
490         }
491 
492         clear(i, j, result_i32.get());
493         setup_row_row_i32(i, j, k, &params_row_i32);
494         MultiThreadGemm<Context, Executor, ParamsRowMajorAsInt32, 2, 4, 8>(
495             &context, params_row_i32);
496         if (!check_row_row_i32(lhs.get(), rhs.get(), result_i32.get(), i, j,
497                                k)) {
498           std::cout << "RowAsInt32(MT): " << i << "x" << j << "x" << k
499                     << " : ERROR" << std::endl;
500           std::cout << "Exiting." << std::endl;
501           std::exit(1);
502         }
503 
504         clear(i, j, result_i32.get());
505         setup_row_col_i32(i, j, k, &params_col_i32);
506         MultiThreadGemm<Context, Executor, ParamsColumnMajorAsInt32, 2, 4, 8>(
507             &context, params_col_i32);
508         if (!check_row_col_i32(lhs.get(), rhs.get(), result_i32.get(), i, j,
509                                k)) {
510           std::cout << "ColumnAsInt32(MT): " << i << "x" << j << "x" << k
511                     << " : ERROR" << std::endl;
512           std::cout << "Exiting." << std::endl;
513           std::exit(1);
514         }
515       }
516     }
517   }
518 
519   std::cout << "OK." << std::endl;
520   return 0;
521 }
522