1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // block_params.h: Logic to choose L1 and L2 block sizes
16 // to optimize cache-friendliness.
17
18 #ifndef GEMMLOWP_INTERNAL_BLOCK_PARAMS_H_
19 #define GEMMLOWP_INTERNAL_BLOCK_PARAMS_H_
20
21 #include "common.h"
22
23 namespace gemmlowp {
24
25 // A BlockParams instance contains a full description of all the block size
26 // parameters to be used by a Gemm.
27 // There are two nested levels of block subdivisions: first a subdivision
28 // into large blocks that should fit in last-level cache (what we call L2 here)
29 // and then another subdivision into smaller blocks that should fit in
30 // L1 cache. There is then actually a third level of subdivision to fit
31 // in registers, but we are not concerned with that here.
32 struct BlockParams {
33 // L1 block parameters determine the size of small blocks that should
34 // fit in L1 cache.
35 int l1_rows;
36 int l1_cols;
37 int l1_depth;
38
39 // L2 block parameters determine the size of larger blocks that should
40 // fit in L2 cache.
41 int l2_rows;
42 int l2_cols;
43 int l2_depth;
44
45 template <typename KernelFormat>
InitBlockParams46 void Init(int rows, int cols, int depth, int num_threads) {
47 FindL2BlockSizes<KernelFormat>(rows, cols, depth, num_threads, &l2_rows,
48 &l2_cols, &l2_depth);
49 FindL1BlockSizes<KernelFormat>(l2_rows, l2_cols, l2_depth, &l1_rows,
50 &l1_cols, &l1_depth);
51 }
52
53 template <typename KernelFormat>
FindL2BlockSizesBlockParams54 static void FindL2BlockSizes(int rows, int cols, int depth, int num_threads,
55 int* out_l2_rows, int* out_l2_cols,
56 int* out_l2_depth) {
57 int l2_rows = 0;
58 int l2_cols = 0;
59 int l2_depth = 0;
60 // No L2 blocking in the depth dimension at the moment.
61 // Too much loss of accuracy due to storing intermediate results in
62 // low precision.
63 // However, we still want to round l2_depth up to the next multiple
64 // of register size, so as to avoid having to special-case unaligned depths.
65 l2_depth = RoundUp<kRegisterSize>(depth);
66
67 const int l2_bytes_to_use = kDefaultL2CacheSize;
68 const float l2_rhs_factor = kDefaultL2RhsFactor;
69
70 {
71 int max_cache_friendly_l2_cols = std::max(
72 1, static_cast<int>(l2_rhs_factor * (l2_bytes_to_use / l2_depth)));
73 int min_l2_cols_blocks =
74 std::max(1, CeilQuotient(cols, max_cache_friendly_l2_cols));
75 l2_cols =
76 RoundUp<KernelFormat::kCols>(CeilQuotient(cols, min_l2_cols_blocks));
77 }
78
79 // No L2 blocking in the row dimension if l2_rhs_factor is 1.0 as the row
80 // dimension concerns only the LHS. Blocking only RHS matrix for L2 enhances
81 // the performance on x86.
82 if (l2_rhs_factor == 1.0f) {
83 l2_rows = RoundUp<KernelFormat::kRows>(rows);
84 } else {
85 int max_cache_friendly_l2_rows =
86 std::max(1, (l2_bytes_to_use - l2_depth * l2_cols) /
87 (num_threads * (l2_depth + 4 * l2_cols)));
88 int min_l2_rows_blocks =
89 std::max(1, CeilQuotient(rows, max_cache_friendly_l2_rows));
90 l2_rows =
91 RoundUp<KernelFormat::kRows>(CeilQuotient(rows, min_l2_rows_blocks));
92 }
93
94 *out_l2_rows = l2_rows;
95 *out_l2_cols = l2_cols;
96 *out_l2_depth = l2_depth;
97 }
98
99 template <typename KernelFormat>
FindL1BlockSizesBlockParams100 static void FindL1BlockSizes(int rows, int cols, int depth, int* out_l1_rows,
101 int* out_l1_cols, int* out_l1_depth) {
102 int l1_rows = 0;
103 int l1_cols = 0;
104 int l1_depth = 0;
105
106 // L2 block sizes should already be multiples of kernel block sizes.
107 assert(rows % KernelFormat::kRows == 0);
108 assert(cols % KernelFormat::kCols == 0);
109 assert(depth % KernelFormat::kDepth == 0);
110
111 // No L1 blocking in the columns dimension at the moment.
112 // Thought not to be needed. Similar to Eigen.
113 l1_cols = cols;
114
115 const int l1_bytes_to_use = kDefaultL1CacheSize;
116
117 {
118 int max_cache_friendly_l1_depth = std::max(
119 1, (l1_bytes_to_use - 4 * KernelFormat::kRows * KernelFormat::kCols) /
120 (KernelFormat::kRows + KernelFormat::kCols));
121 int min_l1_depth_blocks =
122 std::max(1, CeilQuotient(depth, max_cache_friendly_l1_depth));
123 l1_depth =
124 RoundUp<kRegisterSize>(CeilQuotient(depth, min_l1_depth_blocks));
125 }
126
127 {
128 int max_cache_friendly_l1_rows =
129 std::max(1, l1_bytes_to_use / (l1_depth + 4 * l1_cols));
130 int min_l1_rows_blocks =
131 std::max(1, CeilQuotient(rows, max_cache_friendly_l1_rows));
132 l1_rows =
133 RoundUp<KernelFormat::kRows>(CeilQuotient(rows, min_l1_rows_blocks));
134 }
135
136 *out_l1_rows = l1_rows;
137 *out_l1_cols = l1_cols;
138 *out_l1_depth = l1_depth;
139 }
140 };
141
142 // A SideBlockParams instance contains only the block params relevant to
143 // one side (LHS or RHS), expressed in terms of 'width' instead of
// rows/columns. See the explanation in kernel.h: in the LHS, 'width' means
145 // the number of rows, while in the RHS, 'width' means the number of columns.
146 // That allows us to write generic code that applies to either LHS or RHS.
struct SideBlockParams {
  // L1 block parameters determine the size of small blocks that should
  // fit in L1 cache.
  // l1_width is filled by GetSideBlockParams from BlockParams::l1_rows for
  // the LHS and from BlockParams::l1_cols otherwise; l1_depth is copied from
  // BlockParams::l1_depth.
  int l1_width;
  int l1_depth;

  // L2 block parameters determine the size of larger blocks that should
  // fit in L2 cache.
  // l2_width is filled by GetSideBlockParams from BlockParams::l2_rows for
  // the LHS and from BlockParams::l2_cols otherwise; l2_depth is copied from
  // BlockParams::l2_depth.
  int l2_width;
  int l2_depth;
};
158
// Selects which side of the product (LHS or RHS) a SideBlockParams refers to.
// GetSideBlockParams uses it to decide whether 'width' means rows (Lhs) or
// columns (any other value).
enum class Side { Lhs, Rhs };
160
GetSideBlockParams(Side side,SideBlockParams * side_block_params,const BlockParams & block_params)161 inline void GetSideBlockParams(Side side, SideBlockParams* side_block_params,
162 const BlockParams& block_params) {
163 side_block_params->l1_width =
164 side == Side::Lhs ? block_params.l1_rows : block_params.l1_cols;
165 side_block_params->l2_width =
166 side == Side::Lhs ? block_params.l2_rows : block_params.l2_cols;
167
168 side_block_params->l1_depth = block_params.l1_depth;
169 side_block_params->l2_depth = block_params.l2_depth;
170 }
171
172 } // namespace gemmlowp
173
174 #endif // GEMMLOWP_INTERNAL_BLOCK_PARAMS_H_
175