1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // single_thread_gemm.h: Single-threaded GEMM implementation.
16 // This is a good place to start reading code, as it shows the overall
17 // structure of a GEMM and is much simpler than multi_thread_gemm.h.
18
19 #ifndef GEMMLOWP_INTERNAL_SINGLE_THREAD_GEMM_H_
20 #define GEMMLOWP_INTERNAL_SINGLE_THREAD_GEMM_H_
21
#include <algorithm>
#include <cassert>

#include "../public/map.h"
#include "allocator.h"
#include "compute.h"
#include "kernel.h"
#include "pack.h"
#include "unpack.h"
30
31 namespace gemmlowp {
32
33 class SingleThreadGemmContext {
34 public:
allocator()35 Allocator* allocator() { return &allocator_; }
36
37 protected:
38 Allocator allocator_;
39 };
40
41 typedef VectorMap<const int32_t, VectorShape::Col> OffsetColMap;
42 typedef VectorMap<const int32_t, VectorShape::Row> OffsetRowMap;
43 typedef VectorDup<const int32_t, VectorShape::Col> OffsetColDup;
44 typedef VectorDup<const int32_t, VectorShape::Row> OffsetRowDup;
45
46 template <typename KernelFormat, typename InputScalar, typename OutputScalar,
47 typename BitDepthParams, MapOrder LhsOrder, MapOrder RhsOrder,
48 MapOrder ResultOrder, typename LhsOffset, typename RhsOffset,
49 typename OutputPipelineType>
SingleThreadGemm(SingleThreadGemmContext * context,const KernelBase & kernel,const MatrixMap<const InputScalar,LhsOrder> & lhs,const MatrixMap<const InputScalar,RhsOrder> & rhs,MatrixMap<OutputScalar,ResultOrder> * result,const LhsOffset & lhs_offset,const RhsOffset & rhs_offset,const OutputPipelineType & output_pipeline)50 void SingleThreadGemm(SingleThreadGemmContext* context,
51 const KernelBase& kernel,
52 const MatrixMap<const InputScalar, LhsOrder>& lhs,
53 const MatrixMap<const InputScalar, RhsOrder>& rhs,
54 MatrixMap<OutputScalar, ResultOrder>* result,
55 const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
56 const OutputPipelineType& output_pipeline) {
57 ScopedProfilingLabel label("gemmlowp::SingleThreadGemm");
58
59 assert(lhs.cols() == rhs.rows());
60
61 int rows = result->rows();
62 int cols = result->cols();
63 int depth = lhs.cols();
64
65 assert(rows > 0);
66 assert(cols > 0);
67 assert(depth > 0);
68
69 Allocator* allocator = context->allocator();
70
71 BlockParams block_params;
72 block_params.Init<KernelFormat>(rows, cols, depth, 1);
73
74 PackedSideBlock<typename KernelFormat::Lhs> packed_lhs(
75 Side::Lhs, allocator, block_params);
76 PackedSideBlock<typename KernelFormat::Rhs> packed_rhs(
77 Side::Rhs, allocator, block_params);
78
79 PackedResult packed_result(allocator, block_params);
80
81 allocator->Commit();
82
83 const bool pack_rhs_once = block_params.l2_cols == cols;
84
85 if (pack_rhs_once) {
86 PackRhs<BitDepthParams>(&packed_rhs, rhs);
87 }
88
89 for (int r = 0; r < rows; r += block_params.l2_rows) {
90 int rs = std::min(block_params.l2_rows, rows - r);
91
92 PackLhs<BitDepthParams>(&packed_lhs, lhs.block(r, 0, rs, depth));
93
94 for (int c = 0; c < cols; c += block_params.l2_cols) {
95 int cs = std::min(block_params.l2_cols, cols - c);
96
97 if (!pack_rhs_once) {
98 PackRhs<BitDepthParams>(&packed_rhs, rhs.block(0, c, depth, cs));
99 }
100
101 Compute(kernel, block_params, &packed_result, packed_lhs, packed_rhs);
102
103 auto result_block = result->block(r, c, rs, cs);
104 UnpackResult<BitDepthParams>(&result_block, packed_result, depth,
105 packed_lhs.sums_of_each_slice(),
106 packed_rhs.sums_of_each_slice(),
107 lhs_offset, rhs_offset, output_pipeline);
108 }
109 }
110
111 allocator->Decommit();
112 }
113
114 } // namespace gemmlowp
115
116 #endif // GEMMLOWP_INTERNAL_SINGLE_THREAD_GEMM_H_
117