1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/kernels/internal/reference/elu.h"
17
18 #include <algorithm>
19 #include <cmath>
20 #include <functional>
21 #include <limits>
22
23 #include "tensorflow/lite/c/common.h"
24 #include "tensorflow/lite/kernels/internal/quantization_util.h"
25 #include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
26 #include "tensorflow/lite/kernels/internal/types.h"
27 #include "tensorflow/lite/kernels/kernel_util.h"
28 #include "tensorflow/lite/micro/kernels/kernel_util.h"
29
30 namespace tflite {
31 namespace ops {
32 namespace micro {
33 namespace activations {
34 namespace {
35
36 // OLD-TODO(b/142762739): We should figure out a multi-threading plan for most
37 // of the activation ops below.
38
// Per-node state handed from EluPrepare() to EluEval() via node->user_data.
//
// `table` is a 256-entry lookup table used by the int8 path. Each index is an
// input value's bit pattern viewed as uint8_t. Each entry holds the matching
// output value's bit pattern, also viewed as uint8_t.
struct OpData {
  uint8_t table[256] = {};
};
42
43 template <typename T>
PopulateLookupTable(struct OpData * data,const TfLiteTensor * input,TfLiteTensor * output,const std::function<float (float)> & transform)44 void PopulateLookupTable(struct OpData* data, const TfLiteTensor* input,
45 TfLiteTensor* output,
46 const std::function<float(float)>& transform) {
47 static_assert(sizeof(T) == 1, "Lookup table valid only for 8bit");
48 const float inverse_scale = 1 / output->params.scale;
49 int32_t maxval = std::numeric_limits<T>::max();
50 int32_t minval = std::numeric_limits<T>::min();
51 for (int32_t val = minval; val <= maxval; ++val) {
52 const float dequantized =
53 input->params.scale * (val - input->params.zero_point);
54 const float transformed = transform(dequantized);
55 const float rescaled = std::round(transformed * inverse_scale);
56 const int32_t quantized =
57 static_cast<int32_t>(rescaled + output->params.zero_point);
58 data->table[static_cast<uint8_t>(static_cast<T>(val))] =
59 static_cast<uint8_t>(
60 static_cast<T>(std::max(std::min(maxval, quantized), minval)));
61 }
62 }
63
64 // OLD-TODO(b/143696793): move this to optimized_ops.
EvalUsingLookupTable(struct OpData * data,const TfLiteTensor * input,TfLiteTensor * output)65 void EvalUsingLookupTable(struct OpData* data, const TfLiteTensor* input,
66 TfLiteTensor* output) {
67 const int size =
68 MatchingFlatSize(GetTensorShape(input), GetTensorShape(output));
69 uint8_t* output_data = GetTensorData<uint8_t>(output);
70 const uint8_t* input_data = GetTensorData<uint8_t>(input);
71 int i = 0;
72
73 for (; i < size; ++i) {
74 output_data[i] = data->table[input_data[i]];
75 }
76 }
77
78 } // namespace
79
Init(TfLiteContext * context,const char * buffer,size_t length)80 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
81 // This is a builtin op, so we don't use the contents in 'buffer', if any.
82 // Instead, we allocate a new object to carry information from Prepare() to
83 // Eval().
84 return nullptr;
85 }
86
GenericPrepare(TfLiteContext * context,TfLiteNode * node)87 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
88 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
89 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
90 const TfLiteTensor* input;
91 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
92 TfLiteTensor* output;
93 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
94 TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
95
96 return kTfLiteError;
97 }
98
EluPrepare(TfLiteContext * context,TfLiteNode * node)99 TfLiteStatus EluPrepare(TfLiteContext* context, TfLiteNode* node) {
100 const TfLiteTensor* input;
101 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
102 TfLiteTensor* output;
103 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
104 OpData* data = reinterpret_cast<OpData*>(node->user_data);
105
106 // Use LUT to handle quantized elu path.
107 if (input->type == kTfLiteInt8) {
108 PopulateLookupTable<int8_t>(data, input, output, [](float value) {
109 return value < 0.0 ? std::exp(value) - 1.0f : value;
110 });
111 }
112 return GenericPrepare(context, node);
113 }
114
EluEval(TfLiteContext * context,TfLiteNode * node)115 TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
116 const TfLiteTensor* input;
117 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
118 TfLiteTensor* output;
119 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
120 switch (input->type) {
121 case kTfLiteFloat32: {
122 optimized_ops::Elu(GetTensorShape(input), GetTensorData<float>(input),
123 GetTensorShape(output), GetTensorData<float>(output));
124 return kTfLiteOk;
125 }
126 case kTfLiteInt8: {
127 OpData* data = reinterpret_cast<OpData*>(node->user_data);
128 EvalUsingLookupTable(data, input, output);
129 return kTfLiteOk;
130 }
131 default:
132 TF_LITE_KERNEL_LOG(
133 context, "Only float32 and int8 is supported currently, got %s.",
134 TfLiteTypeGetName(input->type));
135 return kTfLiteError;
136 }
137 }
138
139 } // namespace activations
140
Register_ELU()141 TfLiteRegistration* Register_ELU() { return nullptr; }
142
143 } // namespace micro
144 } // namespace ops
145 } // namespace tflite
146