1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16
17 #include "tensorflow/lite/c/builtin_op_data.h"
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/kernels/internal/compatibility.h"
20 #include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
21 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
22 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
23 #include "tensorflow/lite/kernels/internal/tensor.h"
24 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
25 #include "tensorflow/lite/kernels/internal/types.h"
26 #include "tensorflow/lite/kernels/kernel_util.h"
27
28 namespace tflite {
29 namespace ops {
30 namespace builtin {
31 namespace resize_bilinear {
32
// RESIZE_BILINEAR has three implementations in this file; the one used is
// chosen by the template argument passed to Eval at registration time.
enum KernelType {
  kReference = 0,
  kGenericOptimized = 1,  // Neon-free
  kNeonOptimized = 2,
};
39
// Node tensor indices: inputs are (input, size); there is a single output.
constexpr int kInputTensor{0};
constexpr int kSizeTensor{1};
constexpr int kOutputTensor{0};
43
ResizeOutputTensor(TfLiteContext * context,const TfLiteTensor * input,const TfLiteTensor * size,TfLiteTensor * output)44 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
45 const TfLiteTensor* input,
46 const TfLiteTensor* size,
47 TfLiteTensor* output) {
48 const int32* size_data = GetTensorData<int32>(size);
49 // Sanity check, the up/down sampling size should always be positive.
50 TF_LITE_ENSURE(context, size_data[0] > 0);
51 TF_LITE_ENSURE(context, size_data[1] > 0);
52 TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
53 output_size->data[0] = input->dims->data[0];
54 output_size->data[1] = size_data[0];
55 output_size->data[2] = size_data[1];
56 output_size->data[3] = input->dims->data[3];
57 return context->ResizeTensor(context, output, output_size);
58 }
59
Prepare(TfLiteContext * context,TfLiteNode * node)60 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
61 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
62 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
63
64 const TfLiteTensor* input;
65 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
66 const TfLiteTensor* size;
67 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
68 TfLiteTensor* output;
69 TF_LITE_ENSURE_OK(context,
70 GetOutputSafe(context, node, kOutputTensor, &output));
71
72 // TODO(ahentz): Our current implementations rely on the inputs being 4D.
73 TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
74 TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1);
75
76 TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32);
77 // ResizeBilinear creates a float tensor even when the input is made of
78 // integers.
79 output->type = input->type;
80
81 if (!IsConstantTensor(size)) {
82 SetTensorToDynamic(output);
83 return kTfLiteOk;
84 }
85
86 // Ensure params are valid.
87 auto* params =
88 reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
89 if (params->half_pixel_centers && params->align_corners) {
90 context->ReportError(
91 context, "If half_pixel_centers is True, align_corners must be False.");
92 return kTfLiteError;
93 }
94
95 return ResizeOutputTensor(context, input, size, output);
96 }
97
98 template <KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)99 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
100 auto* params =
101 reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
102
103 const TfLiteTensor* input;
104 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
105 TfLiteTensor* output;
106 TF_LITE_ENSURE_OK(context,
107 GetOutputSafe(context, node, kOutputTensor, &output));
108 const TfLiteTensor* size;
109 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
110
111 if (IsDynamicTensor(output)) {
112 TF_LITE_ENSURE_OK(context,
113 ResizeOutputTensor(context, input, size, output));
114 }
115
116 if (output->type == kTfLiteFloat32) {
117 #define TF_LITE_RESIZE_BILINEAR(type, opname, datatype) \
118 tflite::ResizeBilinearParams op_params; \
119 op_params.align_corners = params->align_corners; \
120 op_params.half_pixel_centers = params->half_pixel_centers; \
121 type::opname(op_params, GetTensorShape(input), \
122 GetTensorData<datatype>(input), GetTensorShape(size), \
123 GetTensorData<int32>(size), GetTensorShape(output), \
124 GetTensorData<datatype>(output))
125
126 if (kernel_type == kReference) {
127 TF_LITE_RESIZE_BILINEAR(reference_ops, ResizeBilinear, float);
128 }
129 if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) {
130 TF_LITE_RESIZE_BILINEAR(optimized_ops, ResizeBilinear, float);
131 }
132 } else if (output->type == kTfLiteUInt8) {
133 if (kernel_type == kReference) {
134 TF_LITE_RESIZE_BILINEAR(reference_ops, ResizeBilinear, uint8_t);
135 }
136 if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) {
137 TF_LITE_RESIZE_BILINEAR(optimized_ops, ResizeBilinear, uint8_t);
138 }
139 } else if (output->type == kTfLiteInt8) {
140 TF_LITE_RESIZE_BILINEAR(reference_ops, ResizeBilinearInteger, int8_t);
141 } else if (output->type == kTfLiteInt16) {
142 TF_LITE_RESIZE_BILINEAR(reference_ops, ResizeBilinearInteger, int16_t);
143 #undef TF_LITE_RESIZE_BILINEAR
144 } else {
145 context->ReportError(context, "Output type is %d, requires float.",
146 output->type);
147 return kTfLiteError;
148 }
149
150 return kTfLiteOk;
151 }
152
153 } // namespace resize_bilinear
154
Register_RESIZE_BILINEAR_REF()155 TfLiteRegistration* Register_RESIZE_BILINEAR_REF() {
156 static TfLiteRegistration r = {
157 nullptr, nullptr, resize_bilinear::Prepare,
158 resize_bilinear::Eval<resize_bilinear::kReference>};
159 return &r;
160 }
161
Register_RESIZE_BILINEAR_GENERIC_OPT()162 TfLiteRegistration* Register_RESIZE_BILINEAR_GENERIC_OPT() {
163 static TfLiteRegistration r = {
164 nullptr, nullptr, resize_bilinear::Prepare,
165 resize_bilinear::Eval<resize_bilinear::kGenericOptimized>};
166 return &r;
167 }
168
Register_RESIZE_BILINEAR_NEON_OPT()169 TfLiteRegistration* Register_RESIZE_BILINEAR_NEON_OPT() {
170 static TfLiteRegistration r = {
171 nullptr, nullptr, resize_bilinear::Prepare,
172 resize_bilinear::Eval<resize_bilinear::kNeonOptimized>};
173 return &r;
174 }
175
Register_RESIZE_BILINEAR()176 TfLiteRegistration* Register_RESIZE_BILINEAR() {
177 #ifdef USE_NEON
178 return Register_RESIZE_BILINEAR_NEON_OPT();
179 #else
180 return Register_RESIZE_BILINEAR_GENERIC_OPT();
181 #endif
182 }
183
184 } // namespace builtin
185 } // namespace ops
186 } // namespace tflite
187