1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <stdint.h>
#include <string.h>

#include <limits>
#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
24 
25 namespace tflite {
26 namespace ops {
27 namespace builtin {
28 namespace pad {
29 
// Pad ships in two flavours. This tag is the template argument of Eval<> and
// decides whether the reference_ops or the optimized_ops kernels are called.
enum KernelType {
  kReference = 0,
  kGenericOptimized = 1,
};
35 
36 struct PadContext {
PadContexttflite::ops::builtin::pad::PadContext37   PadContext(TfLiteContext* context, TfLiteNode* node) {
38     input = GetInput(context, node, 0);
39     paddings = GetInput(context, node, 1);
40     if (NumInputs(node) == 3) {
41       constant_values = GetOptionalInputTensor(context, node, 2);
42     } else {
43       constant_values = nullptr;
44     }
45     output = GetOutput(context, node, 0);
46     dims = NumDimensions(input);
47 
48     resizing_category = ResizingCategory::kGenericResize;
49     const int paddings_total = GetTensorShape(paddings).FlatSize();
50     const int32* paddings_data = GetTensorData<int32>(paddings);
51     // Paddings will be a n,2 array, and we need to detect 4D arrays with the
52     // pattern { {0,0}, {a, b}, {c, d}, {0,0} }.
53     if (IsConstantTensor(paddings) && paddings_total == 8 &&
54         (paddings_data[0] == 0 && paddings_data[1] == 0) &&
55         (paddings_data[6] == 0 && paddings_data[7] == 0)) {
56       resizing_category = ResizingCategory::kImageStyle;
57     }
58   }
59   const TfLiteTensor* constant_values;
60   const TfLiteTensor* input;
61   const TfLiteTensor* paddings;
62   TfLiteTensor* output;
63   int dims;
64   ResizingCategory resizing_category;
65 };
66 
67 // Resizes output array based on the input size and padding size. This function
68 // is callable from both Prepare() and Eval() as long as the caller ensures the
69 // paddings data is present.
ResizeOutputTensor(TfLiteContext * context,PadContext * op_context)70 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
71                                 PadContext* op_context) {
72   // Ensures the paddings array is dims x 2.
73   TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 0),
74                     op_context->dims);
75   TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 1), 2);
76 
77   // Determines the size of the output tensor.
78   TfLiteIntArray* input_size = op_context->input->dims;
79   TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
80   const int32* paddings_data = GetTensorData<int32>(op_context->paddings);
81 
82   for (int idx = 0; idx < op_context->dims; ++idx) {
83     int before_padding = *paddings_data++;
84     int after_padding = *paddings_data++;
85 
86     TF_LITE_ENSURE_MSG(context, (before_padding >= 0 && after_padding >= 0),
87                        "Pad value has to be greater than equal to 0.");
88 
89     output_size->data[idx] =
90         (input_size->data[idx] + before_padding + after_padding);
91   }
92 
93   return context->ResizeTensor(context, op_context->output, output_size);
94 }
95 
Prepare(TfLiteContext * context,TfLiteNode * node)96 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
97   TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
98   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
99 
100   PadContext op_context(context, node);
101   TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
102   if (op_context.constant_values != nullptr) {
103     TF_LITE_ENSURE_EQ(context, op_context.input->type,
104                       op_context.constant_values->type);
105   }
106 
107   // TODO(nupurgarg): Current implementations rely on the inputs being <= 4D.
108   TF_LITE_ENSURE(context, op_context.dims <= 4);
109 
110   // Exit early if paddings is a non-const tensor. Set output tensor to
111   // dynamic so output size can be determined in Eval.
112   if (!IsConstantTensor(op_context.paddings)) {
113     SetTensorToDynamic(op_context.output);
114     return kTfLiteOk;
115   }
116   return ResizeOutputTensor(context, &op_context);
117 }
118 
119 template <KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)120 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
121   PadContext op_context(context, node);
122 
123   if (op_context.constant_values != nullptr) {
124     // Ensure that constant_values is a scalar.
125     TF_LITE_ENSURE_EQ(context, NumElements(op_context.constant_values), 1);
126   }
127 
128   // Resize the output tensor if the output tensor is dynamic.
129   if (IsDynamicTensor(op_context.output)) {
130     TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
131   }
132 
133   // TODO(nupurgarg): Change kernel implementation to take in int* instead of
134   // vector<int> to remove malloc from Eval().
135   // Create before and after padding arrays that are accepted by the kernel.
136   std::vector<int> before_padding;
137   std::vector<int> after_padding;
138   const int32* paddings_data = GetTensorData<int32>(op_context.paddings);
139 
140   // TODO(nupurgarg): Change kernel implementation to use padding arrays in
141   // forward order (depth, width, height, batch).
142   // Build paddings in order of int[] = {batch, height, width, depth} to match
143   // kernel implementation of Pad in reference_ops.h and optimized_ops.h.
144   for (int idx = op_context.dims - 1; idx >= 0; --idx) {
145     before_padding.push_back(paddings_data[idx * 2]);
146     after_padding.push_back(paddings_data[idx * 2 + 1]);
147   }
148 
149 #define TF_LITE_PAD(type, op_name, scalar, pad_value)                     \
150   TF_LITE_ENSURE(context, before_padding.size() <= 4);                    \
151   TF_LITE_ENSURE(context, after_padding.size() <= 4);                     \
152   tflite::PadParams op_params;                                            \
153   op_params.left_padding_count = before_padding.size();                   \
154   op_params.right_padding_count = after_padding.size();                   \
155   for (int i = 0; i < op_context.dims; ++i) {                             \
156     op_params.left_padding[i] = before_padding[op_context.dims - 1 - i];  \
157     op_params.right_padding[i] = after_padding[op_context.dims - 1 - i];  \
158   }                                                                       \
159   const scalar pad_value_copy = pad_value;                                \
160                                                                           \
161   type::op_name(op_params, GetTensorShape(op_context.input),              \
162                 GetTensorData<scalar>(op_context.input), &pad_value_copy, \
163                 GetTensorShape(op_context.output),                        \
164                 GetTensorData<scalar>(op_context.output))
165   switch (op_context.input->type) {
166     case kTfLiteFloat32: {
167       float pad_value = op_context.constant_values == nullptr
168                             ? 0.f
169                             : *GetTensorData<float>(op_context.constant_values);
170       if (kernel_type == kReference) {
171         if (op_context.resizing_category == ResizingCategory::kImageStyle) {
172           TF_LITE_PAD(reference_ops, PadImageStyle, float, pad_value);
173         } else {
174           TF_LITE_PAD(reference_ops, Pad, float, pad_value);
175         }
176       } else if (kernel_type == kGenericOptimized) {
177         if (op_context.resizing_category == ResizingCategory::kImageStyle) {
178           TF_LITE_PAD(optimized_ops, PadImageStyle, float, pad_value);
179         } else {
180           TF_LITE_PAD(optimized_ops, Pad, float, pad_value);
181         }
182       }
183     } break;
184     case kTfLiteUInt8: {
185       uint8_t pad_value;
186       if (op_context.constant_values == nullptr) {
187         // Quantized Pad requires that 0 is represented in the quantized
188         // range.
189         TF_LITE_ENSURE(context, op_context.output->params.zero_point >=
190                                     std::numeric_limits<uint8_t>::min());
191         TF_LITE_ENSURE(context, op_context.output->params.zero_point <=
192                                     std::numeric_limits<uint8_t>::max());
193         pad_value = static_cast<uint8_t>(op_context.output->params.zero_point);
194       } else {
195         // Quantized Pad requires that 'constant_values' is represented in the
196         // same quantized range as the input and output tensors.
197         TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point,
198                           op_context.constant_values->params.zero_point);
199         TF_LITE_ENSURE_EQ(context, op_context.output->params.scale,
200                           op_context.constant_values->params.scale);
201         pad_value = *GetTensorData<uint8_t>(op_context.constant_values);
202       }
203       if (kernel_type == kReference) {
204         if (op_context.resizing_category == ResizingCategory::kImageStyle) {
205           TF_LITE_PAD(reference_ops, PadImageStyle, uint8_t, pad_value);
206         } else {
207           TF_LITE_PAD(reference_ops, Pad, uint8_t, pad_value);
208         }
209       } else if (kernel_type == kGenericOptimized) {
210         if (op_context.resizing_category == ResizingCategory::kImageStyle) {
211           TF_LITE_PAD(optimized_ops, PadImageStyle, uint8_t, pad_value);
212         } else {
213           TF_LITE_PAD(optimized_ops, Pad, uint8_t, pad_value);
214         }
215       }
216     } break;
217     case kTfLiteInt8: {
218       int8_t pad_value;
219       if (op_context.constant_values == nullptr) {
220         // Quantized Pad requires that 0 is represented in the quantized
221         // range.
222         TF_LITE_ENSURE(context, op_context.output->params.zero_point >=
223                                     std::numeric_limits<int8_t>::min());
224         TF_LITE_ENSURE(context, op_context.output->params.zero_point <=
225                                     std::numeric_limits<int8_t>::max());
226         pad_value = static_cast<int8_t>(op_context.output->params.zero_point);
227       } else {
228         // Quantized Pad requires that 'constant_values' is represented in the
229         // same quantized range as the input and output tensors.
230         TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point,
231                           op_context.constant_values->params.zero_point);
232         TF_LITE_ENSURE_EQ(context, op_context.output->params.scale,
233                           op_context.constant_values->params.scale);
234         pad_value = *GetTensorData<int8_t>(op_context.constant_values);
235       }
236       if (op_context.resizing_category == ResizingCategory::kImageStyle) {
237         TF_LITE_PAD(reference_ops, PadImageStyle, int8_t, pad_value);
238       } else {
239         TF_LITE_PAD(reference_ops, Pad, int8_t, pad_value);
240       }
241     } break;
242     case kTfLiteInt32: {
243       int32_t pad_value =
244           op_context.constant_values == nullptr
245               ? 0
246               : *GetTensorData<int32_t>(op_context.constant_values);
247       if (kernel_type == kReference) {
248         TF_LITE_PAD(reference_ops, Pad, int32_t, pad_value);
249       } else if (kernel_type == kGenericOptimized) {
250         TF_LITE_PAD(optimized_ops, Pad, int32_t, pad_value);
251       }
252     } break;
253     case kTfLiteInt64: {
254       int64_t pad_value =
255           op_context.constant_values == nullptr
256               ? 0L
257               : *GetTensorData<int64_t>(op_context.constant_values);
258       if (kernel_type == kReference) {
259         TF_LITE_PAD(reference_ops, Pad, int64_t, pad_value);
260       } else if (kernel_type == kGenericOptimized) {
261         TF_LITE_PAD(optimized_ops, Pad, int64_t, pad_value);
262       }
263     } break;
264     default:
265       context->ReportError(context,
266                            "Type %d is currently not supported by Pad.",
267                            op_context.input->type);
268       return kTfLiteError;
269   }
270 #undef TF_LITE_PAD
271   return kTfLiteOk;
272 }
273 
274 }  // namespace pad
275 
Register_PAD_REF()276 TfLiteRegistration* Register_PAD_REF() {
277   static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
278                                  pad::Eval<pad::kReference>};
279   return &r;
280 }
281 
Register_PAD_GENERIC_OPT()282 TfLiteRegistration* Register_PAD_GENERIC_OPT() {
283   static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
284                                  pad::Eval<pad::kGenericOptimized>};
285   return &r;
286 }
287 
Register_PAD()288 TfLiteRegistration* Register_PAD() { return Register_PAD_GENERIC_OPT(); }
289 
290 // Also register Pad as PadV2.
Register_PADV2_REF()291 TfLiteRegistration* Register_PADV2_REF() {
292   static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
293                                  pad::Eval<pad::kReference>};
294   return &r;
295 }
296 
Register_PADV2_GENERIC_OPT()297 TfLiteRegistration* Register_PADV2_GENERIC_OPT() {
298   static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
299                                  pad::Eval<pad::kGenericOptimized>};
300   return &r;
301 }
302 
Register_PADV2()303 TfLiteRegistration* Register_PADV2() { return Register_PADV2_GENERIC_OPT(); }
304 
305 }  // namespace builtin
306 }  // namespace ops
307 }  // namespace tflite
308