1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // unpack.h: unpacking the result blocks computed by compute.h,
16 // storing them into the destination matrix.
17 
18 #ifndef GEMMLOWP_INTERNAL_UNPACK_H_
19 #define GEMMLOWP_INTERNAL_UNPACK_H_
20 
21 #include "allocator.h"
22 #include "block_params.h"
23 #include "output.h"
24 #include "pack.h"
25 
26 #include <cmath>
27 
28 namespace gemmlowp {
29 
30 class PackedResult {
31  public:
PackedResult(Allocator * _allocator,const BlockParams & _block_params)32   PackedResult(Allocator* _allocator, const BlockParams& _block_params)
33       : allocator_(_allocator), block_params_(_block_params) {
34     matrix_handle_ = allocator_->Reserve<std::int32_t>(block_params_.l2_rows *
35                                                        block_params_.l2_cols);
36   }
37 
~PackedResult()38   ~PackedResult() {}
39 
Map()40   MatrixMap<std::int32_t, MapOrder::ColMajor> Map() {
41     return MatrixMap<std::int32_t, MapOrder::ColMajor>(
42         allocator_->GetPointer<std::int32_t>(matrix_handle_),
43         block_params_.l2_rows, block_params_.l2_cols, block_params_.l2_rows);
44   }
45 
Map()46   MatrixMap<const std::int32_t, MapOrder::ColMajor> Map() const {
47     return MatrixMap<const std::int32_t, MapOrder::ColMajor>(
48         allocator_->GetPointer<const std::int32_t>(matrix_handle_),
49         block_params_.l2_rows, block_params_.l2_cols, block_params_.l2_rows);
50   }
51 
52  private:
53   Allocator* allocator_;
54   Allocator::Handle matrix_handle_;
55   const BlockParams& block_params_;
56 };
57 
// Plain description of a rectangular region inside a matrix: where it starts
// and how many rows and columns it spans.
struct MatrixBlockBounds {
  int start_row;  // index of the first row covered by the block
  int start_col;  // index of the first column covered by the block
  int rows;       // number of rows in the block
  int cols;       // number of columns in the block

  // Stores the four values as given; no validation is performed here.
  MatrixBlockBounds(int start_row_, int start_col_, int rows_, int cols_)
      : start_row{start_row_},
        start_col{start_col_},
        rows{rows_},
        cols{cols_} {}
};
70 
71 template <int Rows, int Cols, typename SrcMapType>
PrefetchResultBlock(const SrcMapType & src,const VectorMap<const std::int32_t,VectorShape::Col> & lhs_sums_of_each_slice,int src_row,int src_col)72 void PrefetchResultBlock(const SrcMapType& src,
73                          const VectorMap<const std::int32_t, VectorShape::Col>&
74                              lhs_sums_of_each_slice,
75                          int src_row, int src_col) {
76   const std::int32_t* src_data = src.data(src_row, src_col);
77   const int src_stride = src.stride();
78   const std::int32_t* lhs_sums_data = lhs_sums_of_each_slice.data(src_row);
79   for (int r = 0; r < Rows; r += 4) {
80     Prefetch(lhs_sums_data + r);
81   }
82   for (int c = 0; c < Cols; c++) {
83     for (int r = 0; r < Rows; r += 4) {
84       Prefetch(src_data + r + c * src_stride);
85     }
86   }
87 }
88 
89 template <typename KernelFormat, typename RegisterBlockType,
90           typename SrcMapType, typename LhsOffset, typename RhsOffset,
91           typename OutputPipelineExecutorType, typename DstType>
UnpackResultBlock(const SrcMapType & src,const OutputPipelineExecutorType & executor,DstType * dst,const VectorMap<const std::int32_t,VectorShape::Col> & lhs_sums_of_each_slice,const VectorMap<const std::int32_t,VectorShape::Row> & rhs_sums_of_each_slice,const LhsOffset & lhs_offset,const RhsOffset & rhs_offset,int depth,int src_row,int src_col,int src_global_row,int src_global_col,int dst_row,int dst_col)92 void UnpackResultBlock(const SrcMapType& src,
93                        const OutputPipelineExecutorType& executor, DstType* dst,
94                        const VectorMap<const std::int32_t, VectorShape::Col>&
95                            lhs_sums_of_each_slice,
96                        const VectorMap<const std::int32_t, VectorShape::Row>&
97                            rhs_sums_of_each_slice,
98                        const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
99                        int depth, int src_row, int src_col, int src_global_row,
100                        int src_global_col, int dst_row, int dst_col) {
101   using KernelLhsScalar = typename KernelFormat::Lhs::Scalar;
102   using KernelRhsScalar = typename KernelFormat::Rhs::Scalar;
103   static constexpr int KernelLhsZeroPointInput =
104       ZeroPointInputValue<KernelLhsScalar>::kValue;
105   static constexpr int KernelRhsZeroPointInput =
106       ZeroPointInputValue<KernelRhsScalar>::kValue;
107   auto acc = Load<RegisterBlockType>(src, src_row, src_col);
108   const auto& lhs_sums_of_each_slice_block =
109       LoadForBroadcasting<RegisterBlockType>(lhs_sums_of_each_slice, src_row);
110   const auto& rhs_sums_of_each_slice_block =
111       LoadForBroadcasting<RegisterBlockType>(rhs_sums_of_each_slice, src_col);
112   auto lhs_offset_block =
113       LoadForBroadcasting<RegisterBlockType>(lhs_offset, src_row);
114   auto rhs_offset_block =
115       LoadForBroadcasting<RegisterBlockType>(rhs_offset, src_col);
116   AddConstant<KernelLhsZeroPointInput>(&lhs_offset_block);
117   AddConstant<KernelRhsZeroPointInput>(&rhs_offset_block);
118   BroadcastMulAdd(lhs_sums_of_each_slice_block, rhs_offset_block, &acc);
119   for (int i = 0; i < decltype(rhs_offset_block)::kRegisterCount; i++) {
120     rhs_offset_block.buf.reg[i] = Mul(rhs_offset_block.buf.reg[i], depth);
121   }
122   BroadcastMulAdd(BroadcastAdd(rhs_sums_of_each_slice_block, rhs_offset_block),
123                   lhs_offset_block, &acc);
124   executor.Execute(acc, dst, src_global_row, src_global_col, dst_row, dst_col);
125 }
126 
127 template <typename KernelFormat, typename ResultBlockType,
128           typename PackedResultType, typename LhsOffset, typename RhsOffset,
129           typename OutputPipelineType>
UnpackResult(ResultBlockType * dst,const MatrixBlockBounds & dst_block,const PackedResultType & src,int depth,const std::int32_t * lhs_sums_of_each_slice_ptr,const std::int32_t * rhs_sums_of_each_slice_ptr,const LhsOffset & lhs_offset,const RhsOffset & rhs_offset,const OutputPipelineType & output_pipeline)130 void UnpackResult(ResultBlockType* dst, const MatrixBlockBounds& dst_block,
131                   const PackedResultType& src, int depth,
132                   const std::int32_t* lhs_sums_of_each_slice_ptr,
133                   const std::int32_t* rhs_sums_of_each_slice_ptr,
134                   const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
135                   const OutputPipelineType& output_pipeline) {
136   ScopedProfilingLabel label(ResultBlockType::kOrder == MapOrder::ColMajor
137                                  ? "unpack to column-major"
138                                  : "unpack to row-major");
139   assert(dst_block.start_row >= 0);
140   assert(dst_block.start_row + dst_block.rows <= dst->rows());
141   assert(dst_block.start_col >= 0);
142   assert(dst_block.start_col + dst_block.cols <= dst->cols());
143   const auto src_map = src.Map();
144   const VectorMap<const std::int32_t, VectorShape::Col> lhs_sums_of_each_slice(
145       lhs_sums_of_each_slice_ptr, dst_block.rows);
146   const VectorMap<const std::int32_t, VectorShape::Row> rhs_sums_of_each_slice(
147       rhs_sums_of_each_slice_ptr, dst_block.cols);
148   using Int32x1x1 = RegisterBlock<std::int32_t, 1, 1>;
149   using Int32x4x1 = RegisterBlock<std::int32_t, 4, 1>;
150   using Int32x8x1 = RegisterBlock<std::int32_t, 8, 1>;
151   using Int32x1x4 = RegisterBlock<std::int32_t, 1, 4>;
152   using Int32x4x4 = RegisterBlock<std::int32_t, 4, 4>;
153   using Int32x8x4 = RegisterBlock<std::int32_t, 8, 4>;
154 
155   using DstScalarType = typename ResultBlockType::Scalar;
156   using DstScalarx8x8 = RegisterBlock<DstScalarType, 8, 8>;
157 
158   OutputPipelineExecutor<OutputPipelineType, Int32x1x1>
159       output_pipeline_executor_1x1(output_pipeline);
160   OutputPipelineExecutor<OutputPipelineType, Int32x4x1>
161       output_pipeline_executor_4x1(output_pipeline);
162   OutputPipelineExecutor<OutputPipelineType, Int32x8x1>
163       output_pipeline_executor_8x1(output_pipeline);
164   OutputPipelineExecutor<OutputPipelineType, Int32x1x4>
165       output_pipeline_executor_1x4(output_pipeline);
166   OutputPipelineExecutor<OutputPipelineType, Int32x4x4>
167       output_pipeline_executor_4x4(output_pipeline);
168   OutputPipelineExecutor<OutputPipelineType, Int32x8x4>
169       output_pipeline_executor_8x4(output_pipeline);
170 
171   int c8 = 0;
172   if (ResultBlockType::kOrder == MapOrder::RowMajor) {
173     for (; c8 <= dst_block.cols - 8; c8 += 8) {
174       PrefetchResultBlock<8, 8>(src_map, lhs_sums_of_each_slice, 0, c8);
175       int r = 0;
176       for (; r <= dst_block.rows - 8; r += 8) {
177         const int global_row = r + dst_block.start_row;
178         PrefetchResultBlock<8, 8>(src_map, lhs_sums_of_each_slice, r + 8, c8);
179         DstScalarType dst_colmajor_buf[64];
180         MatrixMap<DstScalarType, MapOrder::ColMajor> dst_colmajor_map(
181             dst_colmajor_buf, 8, 8);
182         for (int cx = 0; cx < 8; cx += 4) {
183           const int c = c8 + cx;
184           const int global_col = c + dst_block.start_col;
185           UnpackResultBlock<KernelFormat, Int32x8x4>(
186               src_map, output_pipeline_executor_8x4, &dst_colmajor_map,
187               lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset,
188               rhs_offset, depth, r, c, global_row, global_col, 0, cx);
189         }
190         StoreFinalOutput(LoadContiguous<DstScalarx8x8>(dst_colmajor_buf), dst,
191                          r + dst_block.start_row, c8 + dst_block.start_col);
192       }
193       for (; r <= dst_block.rows - 4; r += 4) {
194         const int global_row = r + dst_block.start_row;
195         for (int cx = 0; cx < 8; cx += 4) {
196           const int c = c8 + cx;
197           const int global_col = c + dst_block.start_col;
198           UnpackResultBlock<KernelFormat, Int32x4x4>(
199               src_map, output_pipeline_executor_4x4, dst,
200               lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset,
201               rhs_offset, depth, r, c, global_row, global_col, global_row,
202               global_col);
203         }
204       }
205       for (; r < dst_block.rows; r++) {
206         const int global_row = r + dst_block.start_row;
207         for (int cx = 0; cx < 8; cx += 4) {
208           const int c = c8 + cx;
209           const int global_col = c + dst_block.start_col;
210           UnpackResultBlock<KernelFormat, Int32x1x4>(
211               src_map, output_pipeline_executor_1x4, dst,
212               lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset,
213               rhs_offset, depth, r, c, global_row, global_col, global_row,
214               global_col);
215         }
216       }
217     }
218   }
219   int c = c8;
220   for (; c <= dst_block.cols - 4; c += 4) {
221     const int global_col = c + dst_block.start_col;
222     PrefetchResultBlock<8, 4>(src_map, lhs_sums_of_each_slice, 0, c);
223     int r = 0;
224     for (; r <= dst_block.rows - 8; r += 8) {
225       const int global_row = r + dst_block.start_row;
226       PrefetchResultBlock<8, 4>(src_map, lhs_sums_of_each_slice, r + 8, c);
227       UnpackResultBlock<KernelFormat, Int32x8x4>(
228           src_map, output_pipeline_executor_8x4, dst, lhs_sums_of_each_slice,
229           rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
230           global_row, global_col, global_row, global_col);
231     }
232     for (; r <= dst_block.rows - 4; r += 4) {
233       const int global_row = r + dst_block.start_row;
234       UnpackResultBlock<KernelFormat, Int32x4x4>(
235           src_map, output_pipeline_executor_4x4, dst, lhs_sums_of_each_slice,
236           rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
237           global_row, global_col, global_row, global_col);
238     }
239     for (; r < dst_block.rows; r++) {
240       const int global_row = r + dst_block.start_row;
241       UnpackResultBlock<KernelFormat, Int32x1x4>(
242           src_map, output_pipeline_executor_1x4, dst, lhs_sums_of_each_slice,
243           rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
244           global_row, global_col, global_row, global_col);
245     }
246   }
247   for (; c < dst_block.cols; c++) {
248     const int global_col = c + dst_block.start_col;
249     PrefetchResultBlock<8, 1>(src_map, lhs_sums_of_each_slice, 0, c);
250     int r = 0;
251     for (; r <= dst_block.rows - 8; r += 8) {
252       const int global_row = r + dst_block.start_row;
253       PrefetchResultBlock<8, 1>(src_map, lhs_sums_of_each_slice, r + 8, c);
254       UnpackResultBlock<KernelFormat, Int32x8x1>(
255           src_map, output_pipeline_executor_8x1, dst, lhs_sums_of_each_slice,
256           rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
257           global_row, global_col, global_row, global_col);
258     }
259     for (; r <= dst_block.rows - 4; r += 4) {
260       const int global_row = r + dst_block.start_row;
261       UnpackResultBlock<KernelFormat, Int32x4x1>(
262           src_map, output_pipeline_executor_4x1, dst, lhs_sums_of_each_slice,
263           rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
264           global_row, global_col, global_row, global_col);
265     }
266     for (; r < dst_block.rows; r++) {
267       const int global_row = r + dst_block.start_row;
268       UnpackResultBlock<KernelFormat, Int32x1x1>(
269           src_map, output_pipeline_executor_1x1, dst, lhs_sums_of_each_slice,
270           rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
271           global_row, global_col, global_row, global_col);
272     }
273   }
274 }
275 
276 }  // end namespace gemmlowp
277 
278 #endif  // GEMMLOWP_INTERNAL_UNPACK_H_
279