1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/kernels/internal/reference/elu.h"
17 
18 #include <algorithm>
19 #include <cmath>
20 #include <functional>
21 #include <limits>
22 
23 #include "tensorflow/lite/c/common.h"
24 #include "tensorflow/lite/kernels/internal/quantization_util.h"
25 #include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
26 #include "tensorflow/lite/kernels/internal/types.h"
27 #include "tensorflow/lite/kernels/kernel_util.h"
28 #include "tensorflow/lite/micro/kernels/kernel_util.h"
29 
30 namespace tflite {
31 namespace ops {
32 namespace micro {
33 namespace activations {
34 namespace {
35 
36 // OLD-TODO(b/142762739): We should figure out a multi-threading plan for most
37 // of the activation ops below.
38 
// Per-node state carried from Prepare() to Eval().
// For the quantized (int8) path, `table` maps every possible quantized input
// byte directly to its quantized ELU output byte, so Eval is a single
// table lookup per element. Populated by PopulateLookupTable().
struct OpData {
  uint8_t table[256] = {0};
};
42 
43 template <typename T>
PopulateLookupTable(struct OpData * data,const TfLiteTensor * input,TfLiteTensor * output,const std::function<float (float)> & transform)44 void PopulateLookupTable(struct OpData* data, const TfLiteTensor* input,
45                          TfLiteTensor* output,
46                          const std::function<float(float)>& transform) {
47   static_assert(sizeof(T) == 1, "Lookup table valid only for 8bit");
48   const float inverse_scale = 1 / output->params.scale;
49   int32_t maxval = std::numeric_limits<T>::max();
50   int32_t minval = std::numeric_limits<T>::min();
51   for (int32_t val = minval; val <= maxval; ++val) {
52     const float dequantized =
53         input->params.scale * (val - input->params.zero_point);
54     const float transformed = transform(dequantized);
55     const float rescaled = std::round(transformed * inverse_scale);
56     const int32_t quantized =
57         static_cast<int32_t>(rescaled + output->params.zero_point);
58     data->table[static_cast<uint8_t>(static_cast<T>(val))] =
59         static_cast<uint8_t>(
60             static_cast<T>(std::max(std::min(maxval, quantized), minval)));
61   }
62 }
63 
64 // OLD-TODO(b/143696793): move this to optimized_ops.
EvalUsingLookupTable(struct OpData * data,const TfLiteTensor * input,TfLiteTensor * output)65 void EvalUsingLookupTable(struct OpData* data, const TfLiteTensor* input,
66                           TfLiteTensor* output) {
67   const int size =
68       MatchingFlatSize(GetTensorShape(input), GetTensorShape(output));
69   uint8_t* output_data = GetTensorData<uint8_t>(output);
70   const uint8_t* input_data = GetTensorData<uint8_t>(input);
71   int i = 0;
72 
73   for (; i < size; ++i) {
74     output_data[i] = data->table[input_data[i]];
75   }
76 }
77 
78 }  // namespace
79 
Init(TfLiteContext * context,const char * buffer,size_t length)80 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
81   // This is a builtin op, so we don't use the contents in 'buffer', if any.
82   // Instead, we allocate a new object to carry information from Prepare() to
83   // Eval().
84   return nullptr;
85 }
86 
GenericPrepare(TfLiteContext * context,TfLiteNode * node)87 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
88   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
89   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
90   const TfLiteTensor* input;
91   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
92   TfLiteTensor* output;
93   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
94   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
95 
96   return kTfLiteError;
97 }
98 
EluPrepare(TfLiteContext * context,TfLiteNode * node)99 TfLiteStatus EluPrepare(TfLiteContext* context, TfLiteNode* node) {
100   const TfLiteTensor* input;
101   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
102   TfLiteTensor* output;
103   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
104   OpData* data = reinterpret_cast<OpData*>(node->user_data);
105 
106   // Use LUT to handle quantized elu path.
107   if (input->type == kTfLiteInt8) {
108     PopulateLookupTable<int8_t>(data, input, output, [](float value) {
109       return value < 0.0 ? std::exp(value) - 1.0f : value;
110     });
111   }
112   return GenericPrepare(context, node);
113 }
114 
EluEval(TfLiteContext * context,TfLiteNode * node)115 TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
116   const TfLiteTensor* input;
117   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
118   TfLiteTensor* output;
119   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
120   switch (input->type) {
121     case kTfLiteFloat32: {
122       optimized_ops::Elu(GetTensorShape(input), GetTensorData<float>(input),
123                          GetTensorShape(output), GetTensorData<float>(output));
124       return kTfLiteOk;
125     }
126     case kTfLiteInt8: {
127       OpData* data = reinterpret_cast<OpData*>(node->user_data);
128       EvalUsingLookupTable(data, input, output);
129       return kTfLiteOk;
130     }
131     default:
132       TF_LITE_KERNEL_LOG(
133           context, "Only float32 and int8 is supported currently, got %s.",
134           TfLiteTypeGetName(input->type));
135       return kTfLiteError;
136   }
137 }
138 
139 }  // namespace activations
140 
Register_ELU()141 TfLiteRegistration* Register_ELU() { return nullptr; }
142 
143 }  // namespace micro
144 }  // namespace ops
145 }  // namespace tflite
146