1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
17 
18 #include "tensorflow/lite/c/builtin_op_data.h"
19 #include "tensorflow/lite/c/common.h"
20 #include "tensorflow/lite/kernels/internal/common.h"
21 #include "tensorflow/lite/kernels/internal/quantization_util.h"
22 #include "tensorflow/lite/kernels/internal/reference/tanh.h"
23 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25 #include "tensorflow/lite/kernels/op_macros.h"
26 #include "tensorflow/lite/micro/kernels/kernel_util.h"
27 #include "tensorflow/lite/micro/micro_utils.h"
28 
29 namespace tflite {
30 namespace ops {
31 namespace micro {
32 namespace activations {
33 namespace {
// Index passed to GetInput()/GetEvalInput() to fetch the op's input tensor.
constexpr int kInputTensor = 0;
// Index passed to GetOutput()/GetEvalOutput() to fetch the op's output tensor.
constexpr int kOutputTensor = 0;
36 
// Per-node state stored in node->user_data. It is written during the prepare
// step and read back by TanhEval for the quantized code paths.
struct OpData {
  // Copy of the input tensor's quantization zero point
  // (input->params.zero_point).
  int32_t input_zero_point;
  // Value returned by CalculateInputRadius() for uint8/int8 inputs.
  int32_t input_range_radius;
  // std::frexp mantissa of the input rescale factor, scaled by 2^31 and
  // rounded; computed for uint8/int8 inputs.
  int32_t input_multiplier;
  // Binary exponent that accompanies input_multiplier. TanhEval also forwards
  // this field to the int16 kernel.
  int input_left_shift;
};
43 
TanhInit(TfLiteContext * context,const char * buffer,size_t length)44 void* TanhInit(TfLiteContext* context, const char* buffer, size_t length) {
45   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
46   return context->AllocatePersistentBuffer(context, sizeof(OpData));
47 }
48 
CalculateArithmeticOpData(TfLiteContext * context,TfLiteNode * node,OpData * data)49 TfLiteStatus CalculateArithmeticOpData(TfLiteContext* context, TfLiteNode* node,
50                                        OpData* data) {
51   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
52   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
53   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
54   TF_LITE_ENSURE(context, input != nullptr);
55   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
56   TF_LITE_ENSURE(context, output != nullptr);
57 
58   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
59 
60   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
61     static constexpr int kInputIntegerBits = 4;
62     const double input_real_multiplier =
63         static_cast<double>(input->params.scale) *
64         static_cast<double>(1 << (31 - kInputIntegerBits));
65 
66     const double q = std::frexp(input_real_multiplier, &data->input_left_shift);
67     data->input_multiplier = static_cast<int32_t>(TfLiteRound(q * (1ll << 31)));
68 
69     data->input_range_radius =
70         CalculateInputRadius(kInputIntegerBits, data->input_left_shift, 31);
71   }
72   return kTfLiteOk;
73 }
74 
TanhPrepare(TfLiteContext * context,TfLiteNode * node)75 TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) {
76   TFLITE_DCHECK(node->user_data != nullptr);
77 
78   OpData* data = static_cast<OpData*>(node->user_data);
79 
80   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
81   TF_LITE_ENSURE(context, input != nullptr);
82   data->input_zero_point = input->params.zero_point;
83   return CalculateArithmeticOpData(context, node, data);
84 }
85 
86 }  // namespace
87 
TanhEval(TfLiteContext * context,TfLiteNode * node)88 TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
89   const TfLiteEvalTensor* input =
90       tflite::micro::GetEvalInput(context, node, kInputTensor);
91   TfLiteEvalTensor* output =
92       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
93 
94   TFLITE_DCHECK(node->user_data != nullptr);
95   const OpData& data = *(static_cast<const OpData*>(node->user_data));
96 
97   switch (input->type) {
98     case kTfLiteFloat32: {
99       reference_ops::Tanh(tflite::micro::GetTensorShape(input),
100                           tflite::micro::GetTensorData<float>(input),
101                           tflite::micro::GetTensorShape(output),
102                           tflite::micro::GetTensorData<float>(output));
103       return kTfLiteOk;
104     } break;
105     case kTfLiteInt16: {
106       TanhParams params;
107       params.input_left_shift = data.input_left_shift;
108       reference_ops::Tanh(params, tflite::micro::GetTensorShape(input),
109                           tflite::micro::GetTensorData<int16_t>(input),
110                           tflite::micro::GetTensorShape(output),
111                           tflite::micro::GetTensorData<int16_t>(output));
112       return kTfLiteOk;
113     } break;
114     case kTfLiteUInt8: {
115       TanhParams params;
116       params.input_zero_point = data.input_zero_point;
117       params.input_range_radius = data.input_range_radius;
118       params.input_multiplier = data.input_multiplier;
119       params.input_left_shift = data.input_left_shift;
120       reference_ops::Tanh(params, tflite::micro::GetTensorShape(input),
121                           tflite::micro::GetTensorData<uint8_t>(input),
122                           tflite::micro::GetTensorShape(output),
123                           tflite::micro::GetTensorData<uint8_t>(output));
124 
125       return kTfLiteOk;
126     } break;
127     case kTfLiteInt8: {
128       reference_integer_ops::Tanh(
129           data.input_zero_point, data.input_range_radius, data.input_multiplier,
130           data.input_left_shift, tflite::micro::GetTensorShape(input),
131           tflite::micro::GetTensorData<int8_t>(input),
132           tflite::micro::GetTensorShape(output),
133           tflite::micro::GetTensorData<int8_t>(output));
134       return kTfLiteOk;
135     } break;
136     default:
137       TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
138                          TfLiteTypeGetName(input->type),
139                          TfLiteTypeGetName(output->type));
140       return kTfLiteError;
141   }
142 }
143 
144 }  // namespace activations
145 
Register_TANH()146 TfLiteRegistration Register_TANH() {
147   return {/*init=*/activations::TanhInit,
148           /*free=*/nullptr,
149           /*prepare=*/activations::TanhPrepare,
150           /*invoke=*/activations::TanhEval,
151           /*profiling_string=*/nullptr,
152           /*builtin_code=*/0,
153           /*custom_name=*/nullptr,
154           /*version=*/0};
155 }
156 }  // namespace micro
157 }  // namespace ops
158 }  // namespace tflite
159