1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // kernel.h: general definitions for kernels.
16 
17 #ifndef GEMMLOWP_INTERNAL_KERNEL_H_
18 #define GEMMLOWP_INTERNAL_KERNEL_H_
19 
20 #include "../public/bit_depth.h"
21 #include "common.h"
22 
23 namespace gemmlowp {
24 
25 // Explanation of general gemmlowp terminology
26 // ===========================================
27 //
28 // We use the following abbreviations:
29 // LHS = "left-hand side"
30 // RHS = "right-hand side"
31 // Sometimes when referring to either LHS or RHS, we just say a "Side".
32 //
33 // In a matrix product of a MxK matrix times a KxN matrix,
34 // we call K the 'depth'. Note that M is the number of rows
35 // of the result (and of the LHS), and N is the number of columns
36 // of the result (and of the RHS).
37 //
38 // In each of the LHS and RHS matrices, we call 'width' the
39 // other dimension, besides the depth. So in the LHS, 'width'
40 // is the number of rows, while in the RHS, 'width' is the number
41 // of columns.
42 //
// So in the LHS MxK matrix, the depth is K and the width is M.
// And in the RHS KxN matrix, the depth is K and the width is N.
45 //
46 // This is illustrated in this picture:
47 //
48 //                             RHS width
49 //                        <----------------->
50 //                        +-----------------+ ^
51 //                        |       RHS       | | Depth
52 //                        +-----------------+ v
53 //                 ^ +--+ +-----------------+
54 //                 | |L | |                 |
55 //       LHS width | |H | |      Result     |
56 //                 | |S | |                 |
57 //                 v +--+ +-----------------+
58 //                   <-->
59 //                   Depth
60 
61 // Explanation of gemmlowp kernel formats and "cells"
62 // ==================================================
63 //
64 // Kernels operate on small LHS and RHS blocks that fit in registers.
65 // These blocks are stored contiguously in memory, but not always
66 // in a traditional column-major or row-major order; instead,
67 // they consist of a number of sub-blocks, which we call "cells",
68 // that are stored in column-major or row-major order. However,
69 // what really matters to us is not so much rows vs columns, but
70 // rather width vs depth. So we refer to "width-major" and "depth-major"
71 // storage orders. In the LHS, width-major means row-major,
72 // while in the RHS, width-major means column-major.
73 // There is also a third possibility, "diagonal order",
74 // which is unused at the moment.
75 //
76 // We aim to treat both sides, LHS and RHS, on an equal footing,
77 // so we call them both 'sides'. A KernelFormat thus is just a pair
78 // of KernelSideFormat's, one for LHS and one for RHS; each KernelSideFormat
79 // contains a CellFormat and a number of cells; cells are only ever
80 // stacked in the width dimension, which means stacked vertically in the
// LHS and stacked horizontally in the RHS.
82 //
83 // Example
84 // =======
85 //
86 // Let's work out the data layout expected by a kernel having the
87 // following format (the struct names here are defined below in this file):
88 //
89 // KernelFormat<
90 //   KernelSideFormat<CellFormat<3, 4>, 3>,
91 //   KernelSideFormat<CellFormat<5, 4>, 2>
92 // >
93 //
94 // The LHS format, KernelSideFormat<CellFormat<3, 4>, 3>, means:
95 // 3 cells, each cell having dimensions (width=3, depth=4), laid out in
96 // DepthMajor order (the default value, see CellFormat). In the LHS,
97 // DepthMajor means column-major, so the LHS cells are of size 3x4 in
98 // column-major order, so the LHS layout is:
99 //
100 // 0  3  6  9
101 // 1  4  7  10
102 // 2  5  8  11
103 // 12 15 18 21
104 // 13 16 19 22
105 // 14 17 20 23
106 // 24 27 30 33
107 // 25 28 31 34
108 // 26 29 32 35
109 //
110 // The RHS format, KernelSideFormat<CellFormat<5, 4>, 2>, means:
111 // 2 cells each having dimensions (width=5, depth=4), laid out in
112 // DepthMajor order (the default value, see CellFormat). In the RHS,
113 // DepthMajor means row-major, so the RHS cells are of size 4x5 in
114 // row-major order, so the RHS layout is:
115 //
116 // 0  1  2  3  4  20 21 22 23 24
117 // 5  6  7  8  9  25 26 27 28 29
118 // 10 11 12 13 14 30 31 32 33 34
119 // 15 16 17 18 19 35 36 37 38 39
120 
// CellOrder lists the storage orders (i.e. layouts) that a cell may use;
// see the explanation of kernel formats and cells above.
enum class CellOrder {
  DepthMajor,  // Entries sharing the same depth index are contiguous.
  WidthMajor,  // Entries sharing the same width index are contiguous.
  Diagonal     // Diagonal order; only meaningful for square cells.
};
124 
125 // CellFormat describes how data is laid
126 // out in a cell. That is, a CellOrder together with actual dimensions.
127 template <int tWidth, int tDepth, CellOrder tOrder = CellOrder::DepthMajor>
128 struct CellFormat {
129   static constexpr int kWidth = tWidth;
130   static constexpr int kDepth = tDepth;
131   static constexpr CellOrder kOrder = tOrder;
132 
133   static constexpr int kSize = kWidth * kDepth;
134 };
135 
// KernelSideFormat captures the layout of one kernel side (LHS or RHS):
// a CellFormat plus the number of such cells. Cells are only ever stacked
// along the Width dimension, never along the Depth dimension. In the LHS,
// for instance, Width is the rows dimension, so LHS cells are stacked
// vertically.
template <typename tCellFormat, int tCells>
struct KernelSideFormat {
  using Cell = tCellFormat;

  // The scalar type of the Format.
  using Scalar = std::uint8_t;
  // The scalar type of the original input.
  using InputScalar = std::uint8_t;

  static constexpr int kCells = tCells;
  // Cells are not stacked along the depth, so the side is one cell deep.
  static constexpr int kDepth = tCellFormat::kDepth;
  // Cells are stacked along the width, so the cell widths add up.
  static constexpr int kWidth = tCells * tCellFormat::kWidth;
};
151 
152 // KernelSideFormat for int8 fast kernel trick. The original input is uint8, but
153 // packs converts it to int8.
154 template <typename tCellFormat, int tCells>
155 struct KernelSideFormatInt8 : KernelSideFormat<tCellFormat, tCells> {
156   typedef std::int8_t Scalar;
157   typedef std::uint8_t InputScalar;
158 };
159 
160 // KernelSideFormat for int8 inputs, enabling int8 fast kernel trick without
161 // pack conversion.
162 template <typename tCellFormat, int tCells>
163 struct KernelSideFormatInt8Inputs : KernelSideFormat<tCellFormat, tCells> {
164   typedef std::int8_t Scalar;
165   typedef std::int8_t InputScalar;
166 };
167 
// KernelFormat describes fully the input data layout that a kernel expects.
// It consists of two KernelSideFormat's, one for LHS and one for RHS.
//
// Both sides must use cells of the same depth; this is enforced at compile
// time. The exposed constants are:
//   kDepth: the cell depth shared by both sides.
//   kRows:  the total LHS width (LHS cell width times number of LHS cells).
//   kCols:  the total RHS width (RHS cell width times number of RHS cells).
template <typename tLhs, typename tRhs>
struct KernelFormat {
  typedef tLhs Lhs;
  typedef tRhs Rhs;

  // The depth is the dimension along which LHS and RHS are multiplied
  // together, so a mismatch can never describe a valid kernel.
  static_assert(Lhs::Cell::kDepth == Rhs::Cell::kDepth,
                "KernelFormat: LHS and RHS cells must have the same depth");
  static constexpr int kDepth = Lhs::Cell::kDepth;
  static constexpr int kRows = Lhs::Cell::kWidth * Lhs::kCells;
  static constexpr int kCols = Rhs::Cell::kWidth * Rhs::kCells;
};
180 
CellOrderName(CellOrder o)181 inline const char* CellOrderName(CellOrder o) {
182   switch (o) {
183     case CellOrder::DepthMajor:
184       return "DepthMajor";
185     case CellOrder::WidthMajor:
186       return "WidthMajor";
187     case CellOrder::Diagonal:
188       return "Diagonal";
189     default:
190       assert(false);
191       return nullptr;
192   }
193 }
194 
195 // Returns the offset into a cell, at which a given coefficient is stored.
196 template <typename CellFormat>
OffsetIntoCell(int w,int d)197 inline int OffsetIntoCell(int w, int d) {
198   const int size = CellFormat::kWidth;
199   switch (CellFormat::kOrder) {
200     case CellOrder::DepthMajor:
201       return w + d * CellFormat::kWidth;
202     case CellOrder::WidthMajor:
203       return d + w * CellFormat::kDepth;
204     case CellOrder::Diagonal:
205       assert(CellFormat::kWidth == CellFormat::kDepth);
206       return ((size + w - d) * size + d) % (size * size);
207     default:
208       assert(false);
209       return 0;
210   }
211 }
212 
// KernelBase is the virtual base class underlying all kernels.
// Thanks to it, code does not have to be templatized on the exact kernel
// type, only on the kernel format; kernels that share a format can
// therefore share the same packing/unpacking code.
struct KernelBase {
  // Name of the kernel; supplied by each concrete kernel.
  virtual const char* Name() const = 0;

  // The kernel implementation proper. Throughout gemmlowp, the word 'run'
  // consistently designates an inner loop whose implementation is provided
  // by a separate optimized function.
  virtual void Run(std::int32_t* dst_ptr,
                   std::size_t dst_row_stride,
                   std::size_t dst_col_stride,
                   const std::uint8_t* lhs_ptr,
                   const std::uint8_t* rhs_ptr,
                   std::size_t start_depth,
                   std::size_t run_depth) const = 0;

  // Virtual, so that a kernel can be destroyed through a KernelBase pointer.
  virtual ~KernelBase() = default;
};
230 
// ZeroPointInputValue exposes a constant kValue for each supported pairing
// of (input scalar type, kernel scalar type). The primary template is empty,
// so referring to kValue for a pairing without a specialization below fails
// to compile.
// NOTE(review): presumably kValue is the input-domain value that corresponds
// to zero in the kernel's scalar domain -- confirm against the packing code.
template <typename InputKernelScalarType, typename KernelScalarType>
struct ZeroPointInputValue {};

// uint8 input, uint8 kernel scalars.
template <>
struct ZeroPointInputValue<std::uint8_t, std::uint8_t> {
  static constexpr std::uint8_t kValue = 0;
};

// uint8 input, int8 kernel scalars.
template <>
struct ZeroPointInputValue<std::uint8_t, std::int8_t> {
  static constexpr std::uint8_t kValue = 128;
};

// int8 input, int8 kernel scalars. Note that kValue is still declared as
// std::uint8_t here, like in the other specializations.
template <>
struct ZeroPointInputValue<std::int8_t, std::int8_t> {
  static constexpr std::uint8_t kValue = 0;
};
248 
249 }  // namespace gemmlowp
250 
251 #endif  // GEMMLOWP_INTERNAL_KERNEL_H_
252