1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // gemmlowp.h: the main public interface header of gemmlowp.
16 
17 #ifndef GEMMLOWP_PUBLIC_GEMMLOWP_H_
18 #define GEMMLOWP_PUBLIC_GEMMLOWP_H_
19 #include "../internal/kernel_default.h"
20 #include "../internal/multi_thread_gemm.h"
21 #include "../internal/unpack.h"
22 #include "bit_depth.h"
23 #include "map.h"
24 #include "output_stages.h"
25 
26 namespace gemmlowp {
27 
IsRequantizationWorthIt(int rows,int cols)28 inline bool IsRequantizationWorthIt(int rows, int cols) {
29   // We pack depth*(rows+cols) and compute depth*rows*cols.
30   // Thus the ratio of compute/packing cost is rows*cols/(rows+cols)
31   // In the square case rows==cols==N, it becomes N/2.
32   return 2 * rows * cols >= (rows + cols) * kMinimumWidthForRequantization;
33 }
34 
35 class GemmContext : public MultiThreadGemmContext {};
36 
37 // Computes a general matrix product ("GEMM").
38 // This is a version that supports per channel quantization.
39 template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
40           MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
41           typename LhsOffset, typename RhsOffset, typename OutputPipelineType>
GemmWithOutputPipelinePC(GemmContext * context,const MatrixMap<const InputScalar,LhsOrder> & lhs,const MatrixMap<const InputScalar,RhsOrder> & rhs,MatrixMap<OutputScalar,ResultOrder> * result,const LhsOffset & lhs_offset,const RhsOffset & rhs_offset,const OutputPipelineType & output_pipeline)42 void GemmWithOutputPipelinePC(GemmContext* context,
43                               const MatrixMap<const InputScalar, LhsOrder>& lhs,
44                               const MatrixMap<const InputScalar, RhsOrder>& rhs,
45                               MatrixMap<OutputScalar, ResultOrder>* result,
46                               const LhsOffset& lhs_offset,
47                               const RhsOffset& rhs_offset,
48                               const OutputPipelineType& output_pipeline) {
49   assert(lhs.cols() == rhs.rows());
50 
51   int rows = result->rows();
52   int cols = result->cols();
53   int depth = lhs.cols();
54 
55   if (rows == 0 || cols == 0 || depth == 0) {
56     // Vacuous GEMM, return early to avoid having to deal with
57     // zero sizes below.
58     return;
59   }
60 
61   if (cols == 1) {
62     if (IsRequantizationWorthIt(rows, cols)) {
63       typedef DefaultKernel<KernelFamily::Gemv, BitDepthParams> Kernel;
64       MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
65                       BitDepthParams>(context, Kernel(), lhs, rhs, result,
66                                       lhs_offset, rhs_offset, output_pipeline);
67     } else {
68       typedef DefaultKernel<KernelFamily::Gemv, DefaultL8R8BitDepthParams>
69           Kernel;
70       MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
71                       DefaultL8R8BitDepthParams>(context, Kernel(), lhs, rhs,
72                                                  result, lhs_offset, rhs_offset,
73                                                  output_pipeline);
74     }
75   } else {
76     if (IsRequantizationWorthIt(rows, cols)) {
77       typedef DefaultKernel<KernelFamily::Gemm, BitDepthParams> Kernel;
78       MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
79                       BitDepthParams>(context, Kernel(), lhs, rhs, result,
80                                       lhs_offset, rhs_offset, output_pipeline);
81     } else {
82       typedef DefaultKernel<KernelFamily::Gemm, DefaultL8R8BitDepthParams>
83           Kernel;
84       MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
85                       DefaultL8R8BitDepthParams>(context, Kernel(), lhs, rhs,
86                                                  result, lhs_offset, rhs_offset,
87                                                  output_pipeline);
88     }
89   }
90 }
91 
92 // Computes a general matrix product ("GEMM").
93 // This is the legacy version that does not support per channel quantization.
94 // The meaning of the offsets, result_mult_int and result_shift
95 // parameters is the same as in the standard EightBitIntGemm interface
96 // (which is also implemented in the eight_bit_int_gemm directory).
97 template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
98           MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
99           typename OutputPipelineType>
GemmWithOutputPipeline(GemmContext * context,const MatrixMap<const InputScalar,LhsOrder> & lhs,const MatrixMap<const InputScalar,RhsOrder> & rhs,MatrixMap<OutputScalar,ResultOrder> * result,int lhs_offset,int rhs_offset,const OutputPipelineType & output_pipeline)100 void GemmWithOutputPipeline(GemmContext* context,
101                             const MatrixMap<const InputScalar, LhsOrder>& lhs,
102                             const MatrixMap<const InputScalar, RhsOrder>& rhs,
103                             MatrixMap<OutputScalar, ResultOrder>* result,
104                             int lhs_offset, int rhs_offset,
105                             const OutputPipelineType& output_pipeline) {
106   const OffsetColDup lhs_offset_vector(lhs_offset, lhs.rows());
107   const OffsetRowDup rhs_offset_vector(rhs_offset, rhs.cols());
108   GemmWithOutputPipelinePC<InputScalar, OutputScalar, BitDepthParams>(
109       context, lhs, rhs, result, lhs_offset_vector, rhs_offset_vector,
110       output_pipeline);
111 }
112 
113 // Computes a general matrix product ("GEMM").
114 // The meaning of the offsets, result_mult_int and result_shift
115 // parameters is the same as in the standard EightBitIntGemm interface
116 // (which is also implemented in the eight_bit_int_gemm directory).
117 template <typename Scalar, typename BitDepthParams, MapOrder LhsOrder,
118           MapOrder RhsOrder, MapOrder ResultOrder>
Gemm(GemmContext * context,const MatrixMap<const Scalar,LhsOrder> & lhs,const MatrixMap<const Scalar,RhsOrder> & rhs,MatrixMap<Scalar,ResultOrder> * result,int lhs_offset,int rhs_offset,int result_offset,int result_mult_int,int result_shift)119 void Gemm(GemmContext* context, const MatrixMap<const Scalar, LhsOrder>& lhs,
120           const MatrixMap<const Scalar, RhsOrder>& rhs,
121           MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
122           int rhs_offset, int result_offset, int result_mult_int,
123           int result_shift) {
124   GemmWithOutputPipeline<Scalar, Scalar, BitDepthParams>(
125       context, lhs, rhs, result, lhs_offset, rhs_offset,
126       MakeStandardOutputPipeline(result_offset, result_mult_int, result_shift));
127 }
128 
129 }  // namespace gemmlowp
130 
131 #endif  // GEMMLOWP_PUBLIC_GEMMLOWP_H_
132