1 // Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef GEMMLOWP_META_MULTI_THREAD_GEMM_H_
16 #define GEMMLOWP_META_MULTI_THREAD_GEMM_H_
17 
#include <algorithm>
#include <cstdint>
#include <vector>

#include "multi_thread_common.h"
#include "single_thread_gemm.h"
20 
21 namespace gemmlowp {
22 namespace meta {
23 namespace internal {
24 
// Heuristic thresholds used when splitting a GEMM into parallel tasks (see
// PrepareGemmTasks below): a task must cover at least kMinGemmTaskSize
// elements of work (m * n * k) and the partitioned dimension must contribute
// at least kMinGemmTaskDimension rows or columns per task; otherwise the
// problem is executed single-threaded.
const std::int32_t kMinGemmTaskSize = 16000;
const std::int32_t kMinGemmTaskDimension = 4;
27 
28 template <typename Executor, typename Params>
PrepareGemmTask(const Params & params,int kernel_m,int kernel_n,int kernel_k,std::uint8_t * scratch,int m_start,int m,int n_start,int n,std::vector<Params> * tasks)29 std::uint8_t* PrepareGemmTask(const Params& params, int kernel_m, int kernel_n,
30                               int kernel_k, std::uint8_t* scratch, int m_start,
31                               int m, int n_start, int n,
32                               std::vector<Params>* tasks) {
33   tasks->push_back(params);
34   Params& task = tasks->back();
35   task.scratch = scratch;
36 
37   task.m = m;
38   task.lhs =
39       StreamUtil<typename Params::InType, typename Params::LeftStream>::Offset(
40           params.left_stream, params.lhs, m_start, 0);
41 
42   task.n = n;
43   task.rhs =
44       StreamUtil<typename Params::InType, typename Params::RightStream>::Offset(
45           params.right_stream, params.rhs, n_start, 0);
46 
47   task.result =
48       StreamUtil<typename Params::OutType, typename Params::OutputStream>::
49           Offset(params.fused_kernel.output_stream, params.result, m_start,
50                  n_start);
51 
52   return scratch + Executor::template EstimateScratchSize<Params>(
53                        task, kernel_m, kernel_n, kernel_k);
54 }
55 
56 template <typename MultiThreadingContext, typename Executor, typename Params>
PrepareGemmTasks(MultiThreadingContext * context,const Params & params,int kernel_m,int kernel_n,int kernel_k,std::vector<Params> * task_params)57 bool PrepareGemmTasks(MultiThreadingContext* context, const Params& params,
58                       int kernel_m, int kernel_n, int kernel_k,
59                       std::vector<Params>* task_params) {
60   const int max_threads = ResolveMaxThreads(context->max_num_threads());
61   const int max_tasks_by_size =
62       (params.m * params.n * params.k) / kMinGemmTaskSize;
63   const int max_tasks_m = params.m / kMinGemmTaskDimension;
64   const int max_tasks_n = params.n / kMinGemmTaskDimension;
65   const int max_tasks_dimension = std::max(max_tasks_m, max_tasks_n);
66 
67   const int real_tasks = std::max(
68       1,
69       std::min(max_threads, std::min(max_tasks_by_size, max_tasks_dimension)));
70 
71   if (real_tasks == 1) {
72     return false;
73   }
74 
75   std::uint8_t* scratch = params.scratch;
76 
77   if (max_tasks_m > max_tasks_n) {
78     const int m_chunk = params.m / real_tasks;
79     for (int i = 0; i < real_tasks - 1; ++i) {
80       scratch = PrepareGemmTask<Executor, Params>(
81           params, kernel_m, kernel_n, kernel_k, scratch, i * m_chunk, m_chunk,
82           0, params.n, task_params);
83     }
84     const int sum_m = (real_tasks - 1) * m_chunk;
85     PrepareGemmTask<Executor, Params>(params, kernel_m, kernel_n, kernel_k,
86                                       scratch, sum_m, params.m - sum_m, 0,
87                                       params.n, task_params);
88   } else {
89     const int n_chunk = params.n / real_tasks;
90     for (int i = 0; i < real_tasks - 1; ++i) {
91       scratch = PrepareGemmTask<Executor, Params>(
92           params, kernel_m, kernel_n, kernel_k, scratch, 0, params.m,
93           i * n_chunk, n_chunk, task_params);
94     }
95     int sum_n = (real_tasks - 1) * n_chunk;
96     PrepareGemmTask<Executor, Params>(params, kernel_m, kernel_n, kernel_k,
97                                       scratch, 0, params.m, sum_n,
98                                       params.n - sum_n, task_params);
99   }
100 
101   return true;
102 }
103 
104 template <typename Executor, typename Params, int kernel_m, int kernel_n,
105           int kernel_k>
106 struct GemmTaskRunner : gemmlowp::Task {
GemmTaskRunnerGemmTaskRunner107   GemmTaskRunner(const Params& params) : params(params) {}
108 
RunGemmTaskRunner109   void Run() override {
110     Gemm<Executor, Params, kernel_m, kernel_n, kernel_k>(params);
111   }
112 
113   Params params;
114 };
115 
116 }  // namespace internal
117 
118 template <typename MultiThreadingContext, typename Executor, typename Params,
119           int kernel_m, int kernel_n, int kernel_k>
MultiThreadGemm(MultiThreadingContext * context,const Params & params)120 inline void MultiThreadGemm(MultiThreadingContext* context,
121                             const Params& params) {
122   typedef internal::GemmTaskRunner<Executor, Params, kernel_m, kernel_n,
123                                    kernel_k>
124       TaskRunnerType;
125 
126   std::vector<Params> task_params;
127   if (!internal::PrepareGemmTasks<MultiThreadingContext, Executor, Params>(
128           context, params, kernel_m, kernel_n, kernel_k, &task_params)) {
129     Gemm<Executor, Params, kernel_m, kernel_n, kernel_k>(params);
130     return;
131   }
132 
133   auto workers_pool = context->workers_pool();
134   std::vector<Task*> tasks;
135   for (auto& task_param : task_params) {
136     tasks.push_back(new TaskRunnerType(task_param));
137   };
138   workers_pool->Execute(tasks);
139 }
140 
141 }  // namespace meta
142 }  // namespace gemmlowp
143 
144 #endif  // GEMMLOWP_META_MULTI_THREAD_GEMM_H_
145