1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // multi_thread_common.h: Multithreading code shared by different meta gemm 16 // versions. 17 18 #ifndef GEMMLOWP_META_MULTI_THREAD_COMMON_H_ 19 #define GEMMLOWP_META_MULTI_THREAD_COMMON_H_ 20 21 #include "../internal/multi_thread_gemm.h" 22 23 namespace gemmlowp { 24 namespace meta { 25 namespace internal { 26 27 const std::int32_t kMinTaskSize = 16000; 28 const std::int32_t kMinTaskDimension = 4; 29 30 struct TaskRect { 31 std::int32_t m_offset; 32 std::int32_t m; 33 std::int32_t n_offset; 34 std::int32_t n; 35 36 TaskRect(std::int32_t m_offset, std::int32_t m, std::int32_t n_offset, 37 std::int32_t n) 38 : m_offset(m_offset), m(m), n_offset(n_offset), n(n) {} 39 }; 40 41 template <typename IN_TYPE, typename OUT_TYPE, typename F> 42 struct MetaTask : gemmlowp::Task { 43 std::uint8_t* scratch; 44 const IN_TYPE* lhs; 45 const IN_TYPE* rhs; 46 TaskRect task_rect; 47 std::int32_t k; 48 OUT_TYPE* result; 49 std::int32_t result_stride; 50 const F& operation; 51 52 MetaTask(std::uint8_t* scratch, const IN_TYPE* lhs, const IN_TYPE* rhs, 53 const TaskRect& task_rect, std::int32_t k, OUT_TYPE* result, 54 std::int32_t result_stride, const F& operation) 55 : scratch(scratch), 56 lhs(lhs), 57 rhs(rhs), 58 task_rect(task_rect), 59 k(k), 60 result(result), 61 result_stride(result_stride), 62 operation(operation) {} 63 64 void Run() override { 65 const IN_TYPE* task_lhs = lhs + task_rect.m_offset * k; 66 const IN_TYPE* task_rhs = rhs + task_rect.n_offset * k; 67 OUT_TYPE* task_result = 68 result + task_rect.m_offset * result_stride + task_rect.n_offset; 69 operation.ExecuteMatrixMatrix(scratch, task_lhs, task_rhs, task_rect.m, 70 task_rect.n, k, task_result, result_stride); 71 } 72 }; 73 74 std::int32_t ResolveMaxThreads(std::int32_t max_threads) { 75 if (max_threads == 0) { 76 static const int hardware_threads_count = 77 static_cast<int>(sysconf(_SC_NPROCESSORS_CONF)); 78 return hardware_threads_count; 79 } 80 return max_threads; 81 } 82 83 void PrepareTasks(std::int32_t max_tasks, std::int32_t m, std::int32_t n, 84 std::int32_t k, std::vector<internal::TaskRect>* tasks) { 85 const std::int32_t max_tasks_by_size = (m * n * k) / kMinTaskSize; 86 const std::int32_t max_tasks_m = m / kMinTaskDimension; 87 const std::int32_t max_tasks_n = n / kMinTaskDimension; 88 const std::int32_t max_tasks_dimension = std::max(max_tasks_m, max_tasks_n); 89 90 std::int32_t real_tasks = std::max( 91 1, std::min(max_tasks, std::min(max_tasks_by_size, max_tasks_dimension))); 92 93 if (real_tasks == 1) { 94 tasks->push_back(TaskRect(0, m, 0, n)); 95 return; 96 } 97 98 if (max_tasks_m > max_tasks_n) { 99 const std::int32_t m_chunk = m / real_tasks; 100 for (int i = 0; i < real_tasks - 1; ++i) { 101 tasks->push_back(TaskRect(i * m_chunk, m_chunk, 0, n)); 102 } 103 const std::int32_t last_m_offset = (real_tasks - 1) * m_chunk; 104 tasks->push_back(TaskRect(last_m_offset, m - last_m_offset, 0, n)); 105 } else { 106 const std::int32_t n_chunk = n / real_tasks; 107 for (int i = 0; i < real_tasks - 1; ++i) { 108 tasks->push_back(TaskRect(0, m, i * n_chunk, n_chunk)); 109 } 110 const std::int32_t last_n_offset = (real_tasks - 1) * n_chunk; 111 tasks->push_back(TaskRect(0, m, last_n_offset, n - last_n_offset)); 112 } 113 } 114 115 template <typename IN_TYPE, typename OUT_TYPE, typename F> 116 void MultiThreadedMatrixMatrix(gemmlowp::WorkersPool* pool, 117 std::int32_t max_threads, std::uint8_t* scratch, 118 const IN_TYPE* lhs, const IN_TYPE* rhs, 119 std::int32_t m, std::int32_t n, std::int32_t k, 120 OUT_TYPE* result, std::int32_t result_stride, 121 const F& operation) { 122 max_threads = internal::ResolveMaxThreads(max_threads); 123 124 std::vector<internal::TaskRect> task_rects; 125 internal::PrepareTasks(max_threads, m, n, k, &task_rects); 126 127 if (task_rects.size() == 1) { 128 operation.ExecuteMatrixMatrix(scratch, lhs, rhs, m, n, k, result, 129 result_stride); 130 return; 131 } 132 133 std::uint8_t* task_scratch = scratch; 134 std::int32_t scratch_per_thread = operation.ScratchPerThread(m, n, k); 135 std::vector<Task*> tasks; 136 std::for_each( 137 task_rects.begin(), task_rects.end(), 138 [&tasks, &task_scratch, lhs, rhs, k, result, result_stride, operation, 139 scratch_per_thread](internal::TaskRect& rect) { 140 tasks.push_back(new internal::MetaTask<IN_TYPE, OUT_TYPE, F>( 141 task_scratch, lhs, rhs, rect, k, result, result_stride, operation)); 142 task_scratch += scratch_per_thread; 143 }); 144 pool->Execute(tasks); 145 } 146 147 } // namespace internal 148 } // namespace meta 149 } // namespace gemmlowp 150 151 #endif // GEMMLOWP_META_MULTI_THREAD_COMMON_H_ 152