1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
17
18 #include "tensorflow/lite/c/builtin_op_data.h"
19 #include "tensorflow/lite/c/common.h"
20 #include "tensorflow/lite/kernels/internal/common.h"
21 #include "tensorflow/lite/kernels/internal/quantization_util.h"
22 #include "tensorflow/lite/kernels/internal/reference/tanh.h"
23 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25 #include "tensorflow/lite/kernels/op_macros.h"
26 #include "tensorflow/lite/micro/kernels/kernel_util.h"
27 #include "tensorflow/lite/micro/micro_utils.h"
28
29 namespace tflite {
30 namespace ops {
31 namespace micro {
32 namespace activations {
33 namespace {
// Index of the node's single input tensor, used with GetInput/GetEvalInput.
constexpr int kInputTensor{0};
// Index of the node's single output tensor, used with GetOutput/GetEvalOutput.
constexpr int kOutputTensor{0};
36
// Per-node state. The storage is obtained in TanhInit, filled in during
// TanhPrepare and read back in TanhEval.
struct OpData {
  // Copy of the input tensor's quantization zero point.
  int32_t input_zero_point;

  // Result of CalculateInputRadius() for the 8-bit quantized input types.
  int32_t input_range_radius;

  // Mantissa of the real input multiplier (as returned by std::frexp), scaled
  // by 2^31 and rounded to an integer.
  int32_t input_multiplier;

  // Exponent of the real input multiplier (as returned by std::frexp).
  int input_left_shift;
};
43
TanhInit(TfLiteContext * context,const char * buffer,size_t length)44 void* TanhInit(TfLiteContext* context, const char* buffer, size_t length) {
45 TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
46 return context->AllocatePersistentBuffer(context, sizeof(OpData));
47 }
48
CalculateArithmeticOpData(TfLiteContext * context,TfLiteNode * node,OpData * data)49 TfLiteStatus CalculateArithmeticOpData(TfLiteContext* context, TfLiteNode* node,
50 OpData* data) {
51 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
52 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
53 const TfLiteTensor* input = GetInput(context, node, kInputTensor);
54 TF_LITE_ENSURE(context, input != nullptr);
55 TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
56 TF_LITE_ENSURE(context, output != nullptr);
57
58 TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
59
60 if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
61 static constexpr int kInputIntegerBits = 4;
62 const double input_real_multiplier =
63 static_cast<double>(input->params.scale) *
64 static_cast<double>(1 << (31 - kInputIntegerBits));
65
66 const double q = std::frexp(input_real_multiplier, &data->input_left_shift);
67 data->input_multiplier = static_cast<int32_t>(TfLiteRound(q * (1ll << 31)));
68
69 data->input_range_radius =
70 CalculateInputRadius(kInputIntegerBits, data->input_left_shift, 31);
71 }
72 return kTfLiteOk;
73 }
74
TanhPrepare(TfLiteContext * context,TfLiteNode * node)75 TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) {
76 TFLITE_DCHECK(node->user_data != nullptr);
77
78 OpData* data = static_cast<OpData*>(node->user_data);
79
80 const TfLiteTensor* input = GetInput(context, node, kInputTensor);
81 TF_LITE_ENSURE(context, input != nullptr);
82 data->input_zero_point = input->params.zero_point;
83 return CalculateArithmeticOpData(context, node, data);
84 }
85
86 } // namespace
87
TanhEval(TfLiteContext * context,TfLiteNode * node)88 TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
89 const TfLiteEvalTensor* input =
90 tflite::micro::GetEvalInput(context, node, kInputTensor);
91 TfLiteEvalTensor* output =
92 tflite::micro::GetEvalOutput(context, node, kOutputTensor);
93
94 TFLITE_DCHECK(node->user_data != nullptr);
95 const OpData& data = *(static_cast<const OpData*>(node->user_data));
96
97 switch (input->type) {
98 case kTfLiteFloat32: {
99 reference_ops::Tanh(tflite::micro::GetTensorShape(input),
100 tflite::micro::GetTensorData<float>(input),
101 tflite::micro::GetTensorShape(output),
102 tflite::micro::GetTensorData<float>(output));
103 return kTfLiteOk;
104 } break;
105 case kTfLiteInt16: {
106 TanhParams params;
107 params.input_left_shift = data.input_left_shift;
108 reference_ops::Tanh(params, tflite::micro::GetTensorShape(input),
109 tflite::micro::GetTensorData<int16_t>(input),
110 tflite::micro::GetTensorShape(output),
111 tflite::micro::GetTensorData<int16_t>(output));
112 return kTfLiteOk;
113 } break;
114 case kTfLiteUInt8: {
115 TanhParams params;
116 params.input_zero_point = data.input_zero_point;
117 params.input_range_radius = data.input_range_radius;
118 params.input_multiplier = data.input_multiplier;
119 params.input_left_shift = data.input_left_shift;
120 reference_ops::Tanh(params, tflite::micro::GetTensorShape(input),
121 tflite::micro::GetTensorData<uint8_t>(input),
122 tflite::micro::GetTensorShape(output),
123 tflite::micro::GetTensorData<uint8_t>(output));
124
125 return kTfLiteOk;
126 } break;
127 case kTfLiteInt8: {
128 reference_integer_ops::Tanh(
129 data.input_zero_point, data.input_range_radius, data.input_multiplier,
130 data.input_left_shift, tflite::micro::GetTensorShape(input),
131 tflite::micro::GetTensorData<int8_t>(input),
132 tflite::micro::GetTensorShape(output),
133 tflite::micro::GetTensorData<int8_t>(output));
134 return kTfLiteOk;
135 } break;
136 default:
137 TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
138 TfLiteTypeGetName(input->type),
139 TfLiteTypeGetName(output->type));
140 return kTfLiteError;
141 }
142 }
143
144 } // namespace activations
145
Register_TANH()146 TfLiteRegistration Register_TANH() {
147 return {/*init=*/activations::TanhInit,
148 /*free=*/nullptr,
149 /*prepare=*/activations::TanhPrepare,
150 /*invoke=*/activations::TanhEval,
151 /*profiling_string=*/nullptr,
152 /*builtin_code=*/0,
153 /*custom_name=*/nullptr,
154 /*version=*/0};
155 }
156 } // namespace micro
157 } // namespace ops
158 } // namespace tflite
159