1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16 
17 #include "tensorflow/lite/c/builtin_op_data.h"
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/kernels/internal/tensor.h"
20 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
21 #include "tensorflow/lite/kernels/kernel_util.h"
22 
23 namespace tflite {
24 namespace ops {
25 namespace builtin {
26 namespace one_hot {
27 
// Positional indices passed to GetInput() for the op's four inputs.
constexpr int kIndicesTensor{0};
constexpr int kDepthTensor{1};
constexpr int kOnValueTensor{2};
constexpr int kOffValueTensor{3};
// Positional index passed to GetOutput() for the op's single output.
constexpr int kOutputTensor{0};
33 
34 // Convenience utility for destructuring a node into the appropriate tensors and
35 // data for the op. Note that this destructuring is quite cheap, so we can avoid
36 // allocating op-specific, persistent data on the heap.
37 struct OneHotContext {
OneHotContexttflite::ops::builtin::one_hot::OneHotContext38   OneHotContext(TfLiteContext* context, TfLiteNode* node) {
39     indices = GetInput(context, node, kIndicesTensor);
40     depth = GetInput(context, node, kDepthTensor);
41     on_value = GetInput(context, node, kOnValueTensor);
42     off_value = GetInput(context, node, kOffValueTensor);
43     output = GetOutput(context, node, kOutputTensor);
44 
45     const auto* params =
46         reinterpret_cast<TfLiteOneHotParams*>(node->builtin_data);
47     const int indices_dims = indices->dims->size;
48     axis = (params->axis == -1) ? indices_dims : params->axis;
49     output_dims = indices_dims + 1;
50     dtype = on_value->type;
51   }
52 
53   const TfLiteTensor* indices;
54   const TfLiteTensor* depth;
55   const TfLiteTensor* on_value;
56   const TfLiteTensor* off_value;
57   TfLiteTensor* output;
58   int axis;
59   int output_dims;
60   TfLiteType dtype;
61 };
62 
63 template <typename T, typename TI>
OneHotComputeImpl(const OneHotContext & op_context)64 void OneHotComputeImpl(const OneHotContext& op_context) {
65   // prefix_dim_size == # of elements before the axis
66   // depth == # of elements per axis
67   // suffix_dim_size == # of elements after the axis
68   int prefix_dim_size = 1;
69   for (int i = 0; i < op_context.axis; ++i) {
70     prefix_dim_size *= op_context.indices->dims->data[i];
71   }
72   const int suffix_dim_size = NumElements(op_context.indices) / prefix_dim_size;
73   const int depth = *op_context.depth->data.i32;
74 
75   const T on_value = *GetTensorData<T>(op_context.on_value);
76   const T off_value = *GetTensorData<T>(op_context.off_value);
77 
78   // View the indices as a matrix of size:
79   //     prefix_dim_size x suffix_dim_size
80   // View the output as a matrix of size:
81   //     prefix_dim_size x depth x suffix_dim_size
82   // Then the output is:
83   //     output(i, j, k) == (indices(i, k) == j) ? on : off
84   T* output = GetTensorData<T>(op_context.output);
85   const TI* indices = GetTensorData<TI>(op_context.indices);
86   for (int i = 0; i < prefix_dim_size; ++i) {
87     for (int j = 0; j < depth; ++j) {
88       for (int k = 0; k < suffix_dim_size; ++k, ++output) {
89         *output = static_cast<int>(indices[i * suffix_dim_size + k]) == j
90                       ? on_value
91                       : off_value;
92       }
93     }
94   }
95 }
96 
97 template <typename T>
OneHotCompute(const OneHotContext & op_context)98 void OneHotCompute(const OneHotContext& op_context) {
99   if (op_context.indices->type == kTfLiteInt64) {
100     OneHotComputeImpl<T, int64_t>(op_context);
101   } else {
102     OneHotComputeImpl<T, int>(op_context);
103   }
104 }
105 
ResizeOutputTensor(TfLiteContext * context,const OneHotContext & op_context)106 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
107                                 const OneHotContext& op_context) {
108   TF_LITE_ENSURE(context, *op_context.depth->data.i32 >= 0);
109   TfLiteIntArray* output_size = TfLiteIntArrayCreate(op_context.output_dims);
110   for (int i = 0; i < op_context.output_dims; ++i) {
111     if (i < op_context.axis) {
112       output_size->data[i] = op_context.indices->dims->data[i];
113     } else if (i == op_context.axis) {
114       output_size->data[i] = *op_context.depth->data.i32;
115     } else {
116       output_size->data[i] = op_context.indices->dims->data[i - 1];
117     }
118   }
119   return context->ResizeTensor(context, op_context.output, output_size);
120 }
121 
Prepare(TfLiteContext * context,TfLiteNode * node)122 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
123   TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
124   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
125 
126   OneHotContext op_context{context, node};
127   switch (op_context.dtype) {
128     // TODO(b/111744875): Support uint8 and quantization.
129     case kTfLiteFloat32:
130     case kTfLiteInt16:
131     case kTfLiteInt32:
132     case kTfLiteInt64:
133     case kTfLiteInt8:
134     case kTfLiteUInt8:
135     case kTfLiteBool:
136       op_context.output->type = op_context.dtype;
137       break;
138     default:
139       TF_LITE_KERNEL_LOG(context, "Unknown output data type: %s",
140                          TfLiteTypeGetName(op_context.dtype));
141       return kTfLiteError;
142   }
143 
144   TF_LITE_ENSURE(context, op_context.indices->type == kTfLiteInt32 ||
145                               op_context.indices->type == kTfLiteInt64);
146   TF_LITE_ENSURE(context, op_context.axis >= 0 &&
147                               op_context.axis < op_context.output_dims);
148   TF_LITE_ENSURE_EQ(context, NumElements(op_context.depth), 1);
149   TF_LITE_ENSURE_EQ(context, NumElements(op_context.on_value), 1);
150   TF_LITE_ENSURE_EQ(context, NumElements(op_context.off_value), 1);
151   TF_LITE_ENSURE_TYPES_EQ(context, op_context.on_value->type, op_context.dtype);
152   TF_LITE_ENSURE_TYPES_EQ(context, op_context.off_value->type,
153                           op_context.dtype);
154 
155   if (!IsConstantTensor(op_context.depth)) {
156     SetTensorToDynamic(op_context.output);
157     return kTfLiteOk;
158   }
159 
160   return ResizeOutputTensor(context, op_context);
161 }
162 
Eval(TfLiteContext * context,TfLiteNode * node)163 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
164   OneHotContext op_context{context, node};
165 
166   if (IsDynamicTensor(op_context.output)) {
167     ResizeOutputTensor(context, op_context);
168   }
169 
170   switch (op_context.output->type) {
171     case kTfLiteFloat32:
172       OneHotCompute<float>(op_context);
173       break;
174     case kTfLiteInt32:
175       OneHotCompute<int>(op_context);
176       break;
177     case kTfLiteInt64:
178       OneHotCompute<int64_t>(op_context);
179       break;
180     case kTfLiteInt8:
181       OneHotCompute<int8_t>(op_context);
182       break;
183     case kTfLiteUInt8:
184       OneHotCompute<uint8_t>(op_context);
185       break;
186     case kTfLiteBool:
187       OneHotCompute<bool>(op_context);
188       break;
189     default:
190       return kTfLiteError;
191   }
192 
193   return kTfLiteOk;
194 }
195 
196 }  // namespace one_hot
197 
Register_ONE_HOT()198 TfLiteRegistration* Register_ONE_HOT() {
199   static TfLiteRegistration r = {
200       nullptr,
201       nullptr,
202       one_hot::Prepare,
203       one_hot::Eval,
204   };
205   return &r;
206 }
207 
208 }  // namespace builtin
209 }  // namespace ops
210 }  // namespace tflite
211