1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // output_stages.h: public definitions of the output stages that can
16 // be assembled into an output pipeline, to control how internal
17 // 32-bit accumulators are transformed to obtain the final uint8
18 // result matrix entries.
19
20 #ifndef GEMMLOWP_PUBLIC_OUTPUT_STAGES_H_
21 #define GEMMLOWP_PUBLIC_OUTPUT_STAGES_H_
22
23 #include <tuple>
24
25 #include "../internal/common.h"
26
27 namespace gemmlowp {
28
// Output stage that "quantizes down" int32 accumulators to the uint8 scale
// while still returning int32 values. Its output is typically what one would
// then clamp to [0..255] and cast to uint8 (see
// OutputStageSaturatingCastToUint8).
//
// The three parameters below define the mapping applied to each input:
//
//   ((input + result_offset) * result_mult_int + rounding) >> result_shift
//
// where
//
//   rounding = (result_shift < 1) ? 0 : (1 << (result_shift - 1));
//
// Adding `rounding` before the right shift makes the shift round to nearest
// rather than truncate whenever result_shift >= 1.
//
// The members have no default member initializers, so a default-initialized
// instance holds indeterminate values; assign all three before use.
struct OutputStageQuantizeDownInt32ToUint8Scale {
  // Offset added to each input value before the multiplication.
  std::int32_t result_offset;
  // Integer multiplier applied to the offset input.
  std::int32_t result_mult_int;
  // Amount of the final (rounding) right shift.
  std::int32_t result_shift;
};
45
46 // This output stage takes int32 values and returns still int32 values,
47 // but "quantized down" to the uint8 scale; in other words, its output
48 // is typically what one would then clamp to [0..255] and cast to uint8
49 // (see OutputStageSaturatingCastToUint8).
50 //
51 // This "quantization down" process depends on 3 parameters,
52 // result_offset, result_mult_int, result_shift,
53 // and the result is:
54 // ((input + result_offset) * result_mult_int + rounding) >> result_shift
55 // where
56 // rounding = (result_shift < 1) ? 0 : (1 << (result_shift - 1));
57 //
58 // Difference from OutputStageQuantizeDownInt32ToUint8Scale here is that each
59 // row or column of the output (depending on tShape) has its own result_offset
60 // and result_mult_int numbers.
61 template <VectorShape tShape>
62 struct OutputStageQuantizeDownInt32ToUint8ScalePC {
63 VectorMap<const std::int32_t, tShape> result_offset;
64 VectorMap<const std::int32_t, tShape> result_mult_int;
65 std::int32_t result_shift;
66 };
67
// Output stage for int32 values that are expected to already be on the final
// uint8 scale, though not necessarily inside the [0..255] range. It clamps
// each value to [0..255] and returns it cast to uint8. The stage has no
// parameters, so the struct is intentionally empty.
struct OutputStageSaturatingCastToUint8 {};
72
// Output stage that adds a "bias vector" to int32 values and outputs int32
// values. The bias vector should contain int32 entries and be either:
//   - a row vector with as many entries as the result matrix has columns, or
//   - a column vector with as many entries as the result matrix has rows.
// Each input value gets the matching bias entry added to it. The vector is
// broadcast along the other dimension to fit the matrix's shape.
//
// The container type of the bias vector is left generic as VectorType.
template <typename VectorType>
struct OutputStageBiasAddition {
  // Bias entries: one per column (row vector) or one per row (column vector).
  VectorType bias_vector;
};
83
// Output stage that clamps values to lie between the bounds `min` and `max`.
// It can be used to implement "rectified linear unit" (ReLU) activation
// functions in neural networks.
struct OutputStageClamp {
  // Lower bound of the clamp.
  std::int32_t min;
  // Upper bound of the clamp.
  std::int32_t max;
};
91
// Output stage whose name indicates a hyperbolic-tangent (tanh) activation.
// Its implementation is not part of this header.
// NOTE(review): the field descriptions below are inferred from the field
// names only; confirm them against the stage's implementation.
struct OutputStageTanh {
  // Presumably the int32 input value that represents the real number 0.
  std::int32_t real_zero_as_int32;
  // Presumably the int32 quantity that represents the real-valued amplitude
  // (scale) used by the stage — TODO confirm.
  std::int32_t real_amplitude_as_int32;
};
96
97 // An output pipeline is just a std::tuple of output stages.
98 // This function generates a standard output pipeline consisting of two stages:
99 // OutputStageQuantizeDownInt32ToUint8Scale, OutputStageSaturatingCastToUint8.
100 inline std::tuple<OutputStageQuantizeDownInt32ToUint8Scale,
101 OutputStageSaturatingCastToUint8>
MakeStandardOutputPipeline(std::int32_t result_offset,std::int32_t result_mult_int,std::int32_t result_shift)102 MakeStandardOutputPipeline(std::int32_t result_offset,
103 std::int32_t result_mult_int,
104 std::int32_t result_shift) {
105 OutputStageQuantizeDownInt32ToUint8Scale quantize_down_stage;
106 quantize_down_stage.result_offset = result_offset;
107 quantize_down_stage.result_mult_int = result_mult_int;
108 quantize_down_stage.result_shift = result_shift;
109 OutputStageSaturatingCastToUint8 saturating_cast_stage;
110 return std::make_tuple(quantize_down_stage, saturating_cast_stage);
111 }
112
113 // An output pipeline is just a std::tuple of output stages.
114 // This function generates a standard output pipeline consisting of two stages:
115 // OutputStageQuantizeDownInt32ToUint8ScalePC, OutputStageSaturatingCastToUint8.
116 template <VectorShape tShape>
117 inline std::tuple<OutputStageQuantizeDownInt32ToUint8ScalePC<tShape>,
118 OutputStageSaturatingCastToUint8>
MakeStandardOutputPipeline(const VectorMap<const std::int32_t,tShape> & result_offset,const VectorMap<const std::int32_t,tShape> & result_mult_int,std::int32_t result_shift)119 MakeStandardOutputPipeline(const VectorMap<const std::int32_t, tShape>&
120 result_offset,
121 const VectorMap<const std::int32_t, tShape>&
122 result_mult_int,
123 std::int32_t result_shift) {
124 OutputStageQuantizeDownInt32ToUint8ScalePC<tShape> quantize_down_stage;
125 quantize_down_stage.result_offset = result_offset;
126 quantize_down_stage.result_mult_int = result_mult_int;
127 quantize_down_stage.result_shift = result_shift;
128 OutputStageSaturatingCastToUint8 saturating_cast_stage;
129 return std::make_tuple(quantize_down_stage, saturating_cast_stage);
130 }
131
132 } // namespace gemmlowp
133
134 #endif // GEMMLOWP_PUBLIC_OUTPUT_STAGES_H_
135