1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOG_SOFTMAX_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOG_SOFTMAX_H_
17
18 #include "tensorflow/lite/kernels/internal/common.h"
19
20 namespace tflite {
21 namespace reference_integer_ops {
22
LogSoftmax(int32_t input_multiplier,int32_t input_shift,int32_t reverse_multiplier,int32_t reverse_shift,int32_t diff_min,int32_t outer_size,int32_t depth,const int8 * input_data,int8 * output_data)23 inline void LogSoftmax(int32_t input_multiplier, int32_t input_shift,
24 int32_t reverse_multiplier, int32_t reverse_shift,
25 int32_t diff_min, int32_t outer_size, int32_t depth,
26 const int8* input_data, int8* output_data) {
27 static constexpr int8_t kMinInt8 = std::numeric_limits<int8_t>::min();
28 static constexpr int8_t kMaxInt8 = std::numeric_limits<int8_t>::max();
29 static constexpr int32_t kMinInt32 = std::numeric_limits<int32_t>::min();
30
31 // [-16, 0] is mapped to [-128, 127] with 1/16 as scale and 127 as zero
32 // point. This nudges the output to [-255/16, 0].
33 static constexpr int32_t kOutputZeroPoint = 127;
34
35 // All IntegerBits must agree with Prepare function.
36 // Input is chosen as Q5.26 so exp(-1 * 2^5 * 2^-1) = exp(-16) is negligible.
37 static constexpr int kInputIntegerBits = 5;
38 static constexpr int kAccumulationIntegerBits = 12;
39 static constexpr int kOutputIntegerBits = 4;
40 using F5 = gemmlowp::FixedPoint<int32, kInputIntegerBits>;
41 using F12 = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
42
43 for (int outer_index = 0; outer_index < outer_size; ++outer_index) {
44 int8 max_in_row = kMinInt8;
45 for (int inner_index = 0; inner_index < depth; ++inner_index) {
46 max_in_row =
47 std::max(max_in_row, input_data[outer_index * depth + inner_index]);
48 }
49
50 // Accumulator "sum_of_exps_in_q12" is safe from overflowing in 2^12 steps.
51 F12 sum_of_exps_in_q12 = F12::FromRaw(0);
52 for (int inner_index = 0; inner_index < depth; ++inner_index) {
53 int32_t input_diff =
54 static_cast<int32_t>(input_data[outer_index * depth + inner_index]) -
55 max_in_row;
56 if (input_diff >= diff_min) {
57 const int32_t input_diff_in_q5 = MultiplyByQuantizedMultiplier(
58 input_diff, input_multiplier, input_shift);
59 sum_of_exps_in_q12 =
60 sum_of_exps_in_q12 +
61 gemmlowp::Rescale<kAccumulationIntegerBits>(
62 exp_on_negative_values(F5::FromRaw(input_diff_in_q5)));
63 }
64 }
65
66 const int32_t log_sum_of_exps_in_q5 =
67 log_x_for_x_greater_than_or_equal_to_1<kInputIntegerBits>(
68 sum_of_exps_in_q12)
69 .raw();
70
71 // Potentially reduced the valid range. shifted_log_sum_of_exps_in_q5 is
72 // smallest representable in Q5.26 plus the log_sum_of_exps.
73 const int32_t shifted_log_sum_of_exps_in_q5 =
74 log_sum_of_exps_in_q5 + kMinInt32;
75 const int32_t adjusted_diff_min = std::max(
76 diff_min - 1,
77 MultiplyByQuantizedMultiplier(shifted_log_sum_of_exps_in_q5,
78 reverse_multiplier, -reverse_shift));
79
80 for (int inner_index = 0; inner_index < depth; ++inner_index) {
81 int32_t input_diff =
82 static_cast<int32_t>(input_data[outer_index * depth + inner_index]) -
83 max_in_row;
84 // Note use of > below instead of >= above.
85 if (input_diff > adjusted_diff_min) {
86 const int32_t input_diff_in_q5 = MultiplyByQuantizedMultiplier(
87 input_diff, input_multiplier, input_shift);
88
89 // Rescale and downcast.
90 int32_t output_in_q27 =
91 gemmlowp::RoundingDivideByPOT(
92 (input_diff_in_q5 - log_sum_of_exps_in_q5),
93 31 - kInputIntegerBits - kOutputIntegerBits) +
94 kOutputZeroPoint;
95
96 output_in_q27 =
97 std::max(std::min(output_in_q27, static_cast<int32_t>(kMaxInt8)),
98 static_cast<int32_t>(kMinInt8));
99 output_data[outer_index * depth + inner_index] =
100 static_cast<int8_t>(output_in_q27);
101 } else {
102 output_data[outer_index * depth + inner_index] = kMinInt8;
103 }
104 }
105 }
106 }
107
108 } // namespace reference_integer_ops
109 } // namespace tflite
110
111 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_LOG_SOFTMAX_H_
112