1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // multi_thread_common.h: Multithreading code shared by different meta gemm
16 // versions.
17
18 #ifndef GEMMLOWP_META_MULTI_THREAD_COMMON_H_
19 #define GEMMLOWP_META_MULTI_THREAD_COMMON_H_
20
#include <algorithm>
#include <cstddef>
#include <cstdint>

#include "../internal/multi_thread_gemm.h"
22
23 namespace gemmlowp {
24 namespace meta {
25 namespace internal {
26
// Smallest m * n * k volume worth handing to a separate task. PrepareTasks()
// divides the total volume by this value to cap the number of tasks.
constexpr std::int32_t kMinTaskSize = 10000;

// Smallest extent a task may have along the dimension being split.
// PrepareTasks() divides m and n by this value to cap the number of tasks.
constexpr std::int32_t kMinTaskDimension = 6;
29
// Describes the sub-rectangle of the result matrix that one task works on.
struct TaskRect {
  // Index of the first result row covered by the task.
  std::int32_t m_offset;
  // Number of result rows covered by the task.
  std::int32_t m;
  // Index of the first result column covered by the task.
  std::int32_t n_offset;
  // Number of result columns covered by the task.
  std::int32_t n;

  // Builds a rectangle of `row_count` rows starting at `row_begin` and
  // `col_count` columns starting at `col_begin`.
  TaskRect(std::int32_t row_begin, std::int32_t row_count,
           std::int32_t col_begin, std::int32_t col_count)
      : m_offset{row_begin},
        m{row_count},
        n_offset{col_begin},
        n{col_count} {}
};
40
41 template <typename IN_TYPE, typename OUT_TYPE, typename F>
42 struct MetaTask : gemmlowp::Task {
43 std::uint8_t* scratch;
44 const IN_TYPE* lhs;
45 const IN_TYPE* rhs;
46 TaskRect task_rect;
47 std::int32_t k;
48 OUT_TYPE* result;
49 std::int32_t result_stride;
50 const F& operation;
51
MetaTaskMetaTask52 MetaTask(std::uint8_t* scratch, const IN_TYPE* lhs, const IN_TYPE* rhs,
53 const TaskRect& task_rect, std::int32_t k, OUT_TYPE* result,
54 std::int32_t result_stride, const F& operation)
55 : scratch(scratch),
56 lhs(lhs),
57 rhs(rhs),
58 task_rect(task_rect),
59 k(k),
60 result(result),
61 result_stride(result_stride),
62 operation(operation) {}
63
RunMetaTask64 void Run() const override {
65 const IN_TYPE* task_lhs = lhs + task_rect.m_offset * k;
66 const IN_TYPE* task_rhs = rhs + task_rect.n_offset * k;
67 OUT_TYPE* task_result =
68 result + task_rect.m_offset * result_stride + task_rect.n_offset;
69 operation.ExecuteMatrixMatrix(scratch, task_lhs, task_rhs, task_rect.m,
70 task_rect.n, k, task_result, result_stride);
71 }
72 };
73
// Resolves the requested thread count.
//
// A request of 0 means "use the hardware": the number of configured
// processors reported by sysconf() is returned (queried once and cached).
// Any other value is returned unchanged.
//
// Declared inline because this header-defined function would otherwise be
// emitted in every translation unit that includes the header.
inline std::int32_t ResolveMaxThreads(std::int32_t max_threads) {
  if (max_threads != 0) {
    return max_threads;
  }
  // sysconf() reports failure with -1; fall back to a single thread in that
  // case rather than handing a non-positive count to the caller.
  static const int hardware_threads_count =
      std::max(1, static_cast<int>(sysconf(_SC_NPROCESSORS_CONF)));
  return hardware_threads_count;
}
82
PrepareTasks(std::int32_t max_tasks,std::int32_t m,std::int32_t n,std::int32_t k,std::vector<internal::TaskRect> * tasks)83 void PrepareTasks(std::int32_t max_tasks, std::int32_t m, std::int32_t n,
84 std::int32_t k, std::vector<internal::TaskRect>* tasks) {
85 const std::int32_t max_tasks_by_size = (m * n * k) / kMinTaskSize;
86 const std::int32_t max_tasks_m = m / kMinTaskDimension;
87 const std::int32_t max_tasks_n = n / kMinTaskDimension;
88 const std::int32_t max_tasks_dimension = std::max(max_tasks_m, max_tasks_n);
89
90 std::int32_t real_tasks = std::max(
91 1, std::min(max_tasks, std::min(max_tasks_by_size, max_tasks_dimension)));
92
93 if (real_tasks == 1) {
94 tasks->push_back(TaskRect(0, m, 0, n));
95 return;
96 }
97
98 if (max_tasks_m > max_tasks_n) {
99 const std::int32_t m_chunk = m / real_tasks;
100 for (int i = 0; i < real_tasks - 1; ++i) {
101 tasks->push_back(TaskRect(i * m_chunk, m_chunk, 0, n));
102 }
103 const std::int32_t last_m_offset = (real_tasks - 1) * m_chunk;
104 tasks->push_back(TaskRect(last_m_offset, m - last_m_offset, 0, n));
105 } else {
106 const std::int32_t n_chunk = n / real_tasks;
107 for (int i = 0; i < real_tasks - 1; ++i) {
108 tasks->push_back(TaskRect(0, m, i * n_chunk, n_chunk));
109 }
110 const std::int32_t last_n_offset = (real_tasks - 1) * n_chunk;
111 tasks->push_back(TaskRect(0, m, last_n_offset, n - last_n_offset));
112 }
113 }
114
115 template <typename IN_TYPE, typename OUT_TYPE, typename F>
MultiThreadedMatrixMatrix(gemmlowp::WorkersPool * pool,std::int32_t max_threads,std::uint8_t * scratch,const IN_TYPE * lhs,const IN_TYPE * rhs,std::int32_t m,std::int32_t n,std::int32_t k,OUT_TYPE * result,std::int32_t result_stride,const F & operation)116 void MultiThreadedMatrixMatrix(gemmlowp::WorkersPool* pool,
117 std::int32_t max_threads, std::uint8_t* scratch,
118 const IN_TYPE* lhs, const IN_TYPE* rhs,
119 std::int32_t m, std::int32_t n, std::int32_t k,
120 OUT_TYPE* result, std::int32_t result_stride,
121 const F& operation) {
122 max_threads = internal::ResolveMaxThreads(max_threads);
123 if (max_threads > 1) {
124 pool->CreateWorkers(max_threads - 1);
125 }
126
127 std::vector<internal::TaskRect> task_rects;
128 internal::PrepareTasks(max_threads, m, n, k, &task_rects);
129
130 if (task_rects.size() == 1) {
131 operation.ExecuteMatrixMatrix(scratch, lhs, rhs, m, n, k, result,
132 result_stride);
133 return;
134 }
135
136 std::uint8_t* task_scratch = scratch;
137 std::int32_t scratch_per_thread = operation.ScratchPerThread(m, n, k);
138 std::int32_t worker_tasks = task_rects.size() - 1;
139 pool->counter_to_decrement_when_ready().Reset(worker_tasks);
140
141 for (std::int32_t i = 0; i < worker_tasks; ++i) {
142 auto task = new internal::MetaTask<IN_TYPE, OUT_TYPE, F>(
143 task_scratch, lhs, rhs, task_rects[i], k, result, result_stride,
144 operation);
145 pool->StartWorker(i, task);
146 task_scratch += scratch_per_thread;
147 }
148
149 {
150 internal::MetaTask<IN_TYPE, OUT_TYPE, F> master_task(
151 task_scratch, lhs, rhs, task_rects.back(), k, result, result_stride,
152 operation);
153 master_task.Run();
154 }
155
156 pool->counter_to_decrement_when_ready().Wait();
157 }
158
159 } // namespace internal
160 } // namespace meta
161 } // namespace gemmlowp
162
163 #endif // GEMMLOWP_META_MULTI_THREAD_COMMON_H_
164