/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace one_hot {

constexpr int kIndicesTensor = 0;
constexpr int kDepthTensor = 1;
constexpr int kOnValueTensor = 2;
constexpr int kOffValueTensor = 3;
constexpr int kOutputTensor = 0;

// Convenience utility for destructuring a node into the appropriate tensors and
// data for the op. Note that this destructuring is quite cheap, so we can avoid
// allocating op-specific, persistent data on the heap.
struct OneHotContext {
  OneHotContext(TfLiteContext* context, TfLiteNode* node) {
    indices = GetInput(context, node, kIndicesTensor);
    depth = GetInput(context, node, kDepthTensor);
    on_value = GetInput(context, node, kOnValueTensor);
    off_value = GetInput(context, node, kOffValueTensor);
    output = GetOutput(context, node, kOutputTensor);

    const auto* params =
        reinterpret_cast<TfLiteOneHotParams*>(node->builtin_data);
    const int indices_dims = indices->dims->size;
    axis = (params->axis == -1) ? indices_dims : params->axis;
    output_dims = indices_dims + 1;
    dtype = on_value->type;
  }

  const TfLiteTensor* indices;
  const TfLiteTensor* depth;
  const TfLiteTensor* on_value;
  const TfLiteTensor* off_value;
  TfLiteTensor* output;
  int axis;
  int output_dims;
  TfLiteType dtype;
};

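// For example (illustrative values): with indices = [[0, 2], [1, -1]],
// depth = 3, axis = -1, on_value = 1 and off_value = 0, the output has shape
// [2, 2, 3] and equals
//   [[[1, 0, 0], [0, 0, 1]],
//    [[0, 1, 0], [0, 0, 0]]]
// Any index outside [0, depth), such as -1, selects off_value everywhere
// along the one-hot axis.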
template <typename T, typename TI>
void OneHotComputeImpl(const OneHotContext& op_context) {
  // prefix_dim_size == # of elements before the axis
  // depth == # of elements per axis
  // suffix_dim_size == # of elements after the axis
  int prefix_dim_size = 1;
  for (int i = 0; i < op_context.axis; ++i) {
    prefix_dim_size *= op_context.indices->dims->data[i];
  }
  if (prefix_dim_size == 0) {
    // Degenerate (zero-element) indices produce a degenerate output; bailing
    // out here also avoids a division by zero below.
    return;
  }
  const int suffix_dim_size = NumElements(op_context.indices) / prefix_dim_size;
  const int depth = *op_context.depth->data.i32;

  const T on_value = *GetTensorData<T>(op_context.on_value);
  const T off_value = *GetTensorData<T>(op_context.off_value);

  // View the indices as a matrix of size:
  //     prefix_dim_size x suffix_dim_size
  // View the output as a matrix of size:
  //     prefix_dim_size x depth x suffix_dim_size
  // Then the output is:
  //     output(i, j, k) == (indices(i, k) == j) ? on : off
  T* output = GetTensorData<T>(op_context.output);
  const TI* indices = GetTensorData<TI>(op_context.indices);
  for (int i = 0; i < prefix_dim_size; ++i) {
    for (int j = 0; j < depth; ++j) {
      for (int k = 0; k < suffix_dim_size; ++k, ++output) {
        *output = static_cast<int>(indices[i * suffix_dim_size + k]) == j
                      ? on_value
                      : off_value;
      }
    }
  }
}

template <typename T>
void OneHotCompute(const OneHotContext& op_context) {
  if (op_context.indices->type == kTfLiteInt64) {
    OneHotComputeImpl<T, int64_t>(op_context);
  } else {
    OneHotComputeImpl<T, int>(op_context);
  }
}

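// Builds the output shape by splicing `depth` into the indices shape at
// `axis`. For example (illustrative values): indices of shape [2, 3] with
// depth = 4 and axis = 1 yield an output of shape [2, 4, 3].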
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                const OneHotContext& op_context) {
  TF_LITE_ENSURE(context, *op_context.depth->data.i32 >= 0);
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(op_context.output_dims);
  for (int i = 0; i < op_context.output_dims; ++i) {
    if (i < op_context.axis) {
      output_size->data[i] = op_context.indices->dims->data[i];
    } else if (i == op_context.axis) {
      output_size->data[i] = *op_context.depth->data.i32;
    } else {
      output_size->data[i] = op_context.indices->dims->data[i - 1];
    }
  }
  return context->ResizeTensor(context, op_context.output, output_size);
}

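// Validates the node's inputs and output type. When the depth tensor is a
// compile-time constant the output can be resized here; otherwise the output
// is marked dynamic and resized on every call to Eval.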
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  OneHotContext op_context{context, node};
  switch (op_context.dtype) {
    // TODO(b/111744875): Support uint8 and quantization.
    case kTfLiteFloat32:
    case kTfLiteInt16:
    case kTfLiteInt32:
    case kTfLiteInt64:
    case kTfLiteInt8:
    case kTfLiteUInt8:
    case kTfLiteBool:
      op_context.output->type = op_context.dtype;
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Unknown output data type: %s",
                         TfLiteTypeGetName(op_context.dtype));
      return kTfLiteError;
  }

  TF_LITE_ENSURE(context, op_context.indices->type == kTfLiteInt32 ||
                              op_context.indices->type == kTfLiteInt64);
  TF_LITE_ENSURE(context, op_context.axis >= 0 &&
                              op_context.axis < op_context.output_dims);
  // The depth tensor is dereferenced via data.i32, so ensure it really holds
  // int32 values.
  TF_LITE_ENSURE_TYPES_EQ(context, op_context.depth->type, kTfLiteInt32);
  TF_LITE_ENSURE_EQ(context, NumElements(op_context.depth), 1);
  TF_LITE_ENSURE_EQ(context, NumElements(op_context.on_value), 1);
  TF_LITE_ENSURE_EQ(context, NumElements(op_context.off_value), 1);
  TF_LITE_ENSURE_TYPES_EQ(context, op_context.on_value->type, op_context.dtype);
  TF_LITE_ENSURE_TYPES_EQ(context, op_context.off_value->type,
                          op_context.dtype);

  if (!IsConstantTensor(op_context.depth)) {
    SetTensorToDynamic(op_context.output);
    return kTfLiteOk;
  }

  return ResizeOutputTensor(context, op_context);
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OneHotContext op_context{context, node};

  if (IsDynamicTensor(op_context.output)) {
    // Propagate a resize failure instead of silently ignoring the status.
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, op_context));
  }

  switch (op_context.output->type) {
    case kTfLiteFloat32:
      OneHotCompute<float>(op_context);
      break;
    // Prepare admits kTfLiteInt16, so dispatch it here as well rather than
    // falling through to the error default.
    case kTfLiteInt16:
      OneHotCompute<int16_t>(op_context);
      break;
    case kTfLiteInt32:
      OneHotCompute<int>(op_context);
      break;
    case kTfLiteInt64:
      OneHotCompute<int64_t>(op_context);
      break;
    case kTfLiteInt8:
      OneHotCompute<int8_t>(op_context);
      break;
    case kTfLiteUInt8:
      OneHotCompute<uint8_t>(op_context);
      break;
    case kTfLiteBool:
      OneHotCompute<bool>(op_context);
      break;
    default:
      return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace one_hot

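// The first four TfLiteRegistration fields are init, free, prepare and
// invoke. init/free are nullptr here because, as noted above, OneHotContext
// is cheap to rebuild per call and no persistent op data is allocated.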
TfLiteRegistration* Register_ONE_HOT() {
  static TfLiteRegistration r = {
      nullptr,
      nullptr,
      one_hot::Prepare,
      one_hot::Eval,
  };
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite