1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // gemmlowp.h: the main public interface header of gemmlowp.
16
17 #ifndef GEMMLOWP_PUBLIC_GEMMLOWP_H_
18 #define GEMMLOWP_PUBLIC_GEMMLOWP_H_
#include <cstdint>

#include "../internal/kernel_default.h"
#include "../internal/multi_thread_gemm.h"
#include "../internal/unpack.h"
#include "bit_depth.h"
#include "map.h"
#include "output_stages.h"
25
26 namespace gemmlowp {
27
IsRequantizationWorthIt(int rows,int cols)28 inline bool IsRequantizationWorthIt(int rows, int cols) {
29 // We pack depth*(rows+cols) and compute depth*rows*cols.
30 // Thus the ratio of compute/packing cost is rows*cols/(rows+cols)
31 // In the square case rows==cols==N, it becomes N/2.
32 return 2 * rows * cols >= (rows + cols) * kMinimumWidthForRequantization;
33 }
34
35 class GemmContext : public MultiThreadGemmContext {};
36
37 // Computes a general matrix product ("GEMM").
38 // This is a version that supports per channel quantization.
39 template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
40 MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
41 typename LhsOffset, typename RhsOffset, typename OutputPipelineType>
GemmWithOutputPipelinePC(GemmContext * context,const MatrixMap<const InputScalar,LhsOrder> & lhs,const MatrixMap<const InputScalar,RhsOrder> & rhs,MatrixMap<OutputScalar,ResultOrder> * result,const LhsOffset & lhs_offset,const RhsOffset & rhs_offset,const OutputPipelineType & output_pipeline)42 void GemmWithOutputPipelinePC(GemmContext* context,
43 const MatrixMap<const InputScalar, LhsOrder>& lhs,
44 const MatrixMap<const InputScalar, RhsOrder>& rhs,
45 MatrixMap<OutputScalar, ResultOrder>* result,
46 const LhsOffset& lhs_offset,
47 const RhsOffset& rhs_offset,
48 const OutputPipelineType& output_pipeline) {
49 assert(lhs.cols() == rhs.rows());
50
51 int rows = result->rows();
52 int cols = result->cols();
53 int depth = lhs.cols();
54
55 if (rows == 0 || cols == 0 || depth == 0) {
56 // Vacuous GEMM, return early to avoid having to deal with
57 // zero sizes below.
58 return;
59 }
60
61 if (cols == 1) {
62 if (IsRequantizationWorthIt(rows, cols)) {
63 typedef DefaultKernel<KernelFamily::Gemv, BitDepthParams> Kernel;
64 MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
65 BitDepthParams>(context, Kernel(), lhs, rhs, result,
66 lhs_offset, rhs_offset, output_pipeline);
67 } else {
68 typedef DefaultKernel<KernelFamily::Gemv, DefaultL8R8BitDepthParams>
69 Kernel;
70 MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
71 DefaultL8R8BitDepthParams>(context, Kernel(), lhs, rhs,
72 result, lhs_offset, rhs_offset,
73 output_pipeline);
74 }
75 } else {
76 if (IsRequantizationWorthIt(rows, cols)) {
77 typedef DefaultKernel<KernelFamily::Gemm, BitDepthParams> Kernel;
78 MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
79 BitDepthParams>(context, Kernel(), lhs, rhs, result,
80 lhs_offset, rhs_offset, output_pipeline);
81 } else {
82 typedef DefaultKernel<KernelFamily::Gemm, DefaultL8R8BitDepthParams>
83 Kernel;
84 MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
85 DefaultL8R8BitDepthParams>(context, Kernel(), lhs, rhs,
86 result, lhs_offset, rhs_offset,
87 output_pipeline);
88 }
89 }
90 }
91
92 // Computes a general matrix product ("GEMM").
93 // This is the legacy version that does not support per channel quantization.
94 // The meaning of the offsets, result_mult_int and result_shift
95 // parameters is the same as in the standard EightBitIntGemm interface
96 // (which is also implemented in the eight_bit_int_gemm directory).
97 template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
98 MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
99 typename OutputPipelineType>
GemmWithOutputPipeline(GemmContext * context,const MatrixMap<const InputScalar,LhsOrder> & lhs,const MatrixMap<const InputScalar,RhsOrder> & rhs,MatrixMap<OutputScalar,ResultOrder> * result,int lhs_offset,int rhs_offset,const OutputPipelineType & output_pipeline)100 void GemmWithOutputPipeline(GemmContext* context,
101 const MatrixMap<const InputScalar, LhsOrder>& lhs,
102 const MatrixMap<const InputScalar, RhsOrder>& rhs,
103 MatrixMap<OutputScalar, ResultOrder>* result,
104 int lhs_offset, int rhs_offset,
105 const OutputPipelineType& output_pipeline) {
106 const OffsetColDup lhs_offset_vector(lhs_offset, lhs.rows());
107 const OffsetRowDup rhs_offset_vector(rhs_offset, rhs.cols());
108 GemmWithOutputPipelinePC<InputScalar, OutputScalar, BitDepthParams>(
109 context, lhs, rhs, result, lhs_offset_vector, rhs_offset_vector,
110 output_pipeline);
111 }
112
113 // Computes a general matrix product ("GEMM").
114 // The meaning of the offsets, result_mult_int and result_shift
115 // parameters is the same as in the standard EightBitIntGemm interface
116 // (which is also implemented in the eight_bit_int_gemm directory).
117 template <typename Scalar, typename BitDepthParams, MapOrder LhsOrder,
118 MapOrder RhsOrder, MapOrder ResultOrder>
Gemm(GemmContext * context,const MatrixMap<const Scalar,LhsOrder> & lhs,const MatrixMap<const Scalar,RhsOrder> & rhs,MatrixMap<Scalar,ResultOrder> * result,int lhs_offset,int rhs_offset,int result_offset,int result_mult_int,int result_shift)119 void Gemm(GemmContext* context, const MatrixMap<const Scalar, LhsOrder>& lhs,
120 const MatrixMap<const Scalar, RhsOrder>& rhs,
121 MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
122 int rhs_offset, int result_offset, int result_mult_int,
123 int result_shift) {
124 GemmWithOutputPipeline<Scalar, Scalar, BitDepthParams>(
125 context, lhs, rhs, result, lhs_offset, rhs_offset,
126 MakeStandardOutputPipeline(result_offset, result_mult_int, result_shift));
127 }
128
129 } // namespace gemmlowp
130
131 #endif // GEMMLOWP_PUBLIC_GEMMLOWP_H_
132