1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <stdint.h>
17 
18 #include "tensorflow/lite/c/builtin_op_data.h"
19 #include "tensorflow/lite/c/common.h"
20 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21 #include "tensorflow/lite/kernels/internal/tensor.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24 
25 namespace tflite {
26 namespace ops {
27 namespace builtin {
28 namespace reverse_sequence {
29 namespace {
30 
// Positions of this op's tensors in the node's input list.
constexpr int kInputTensor{0};       // Data tensor whose sequences get reversed.
constexpr int kSeqLengthsTensor{1};  // 1-D tensor of per-batch sequence lengths.
// Position of the single tensor in the node's output list.
constexpr int kOutputTensor{0};
34 
Prepare(TfLiteContext * context,TfLiteNode * node)35 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
36   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
37   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
38 
39   const TfLiteTensor* input;
40   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
41   const TfLiteTensor* seq_lengths;
42   TF_LITE_ENSURE_OK(
43       context, GetInputSafe(context, node, kSeqLengthsTensor, &seq_lengths));
44   TF_LITE_ENSURE_EQ(context, NumDimensions(seq_lengths), 1);
45 
46   if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
47       input->type != kTfLiteUInt8 && input->type != kTfLiteInt16 &&
48       input->type != kTfLiteInt64) {
49     context->ReportError(context,
50                          "Type '%s' is not supported by reverse_sequence.",
51                          TfLiteTypeGetName(input->type));
52     return kTfLiteError;
53   }
54 
55   if (seq_lengths->type != kTfLiteInt32 && seq_lengths->type != kTfLiteInt64) {
56     context->ReportError(
57         context, "Seq_lengths type '%s' is not supported by reverse_sequence.",
58         TfLiteTypeGetName(seq_lengths->type));
59     return kTfLiteError;
60   }
61 
62   TfLiteTensor* output;
63   TF_LITE_ENSURE_OK(context,
64                     GetOutputSafe(context, node, kOutputTensor, &output));
65   TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
66   TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
67 
68   return context->ResizeTensor(context, output, output_shape);
69 }
70 
71 template <typename T, typename TS>
ReverseSequenceImpl(TfLiteContext * context,TfLiteNode * node)72 TfLiteStatus ReverseSequenceImpl(TfLiteContext* context, TfLiteNode* node) {
73   const TfLiteTensor* input;
74   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
75   const TfLiteTensor* seq_lengths_tensor;
76   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSeqLengthsTensor,
77                                           &seq_lengths_tensor));
78   const TS* seq_lengths = GetTensorData<TS>(seq_lengths_tensor);
79 
80   auto* params =
81       reinterpret_cast<TfLiteReverseSequenceParams*>(node->builtin_data);
82   int seq_dim = params->seq_dim;
83   int batch_dim = params->batch_dim;
84 
85   TF_LITE_ENSURE(context, seq_dim >= 0);
86   TF_LITE_ENSURE(context, batch_dim >= 0);
87   TF_LITE_ENSURE(context, seq_dim != batch_dim);
88   TF_LITE_ENSURE(context, seq_dim < NumDimensions(input));
89   TF_LITE_ENSURE(context, batch_dim < NumDimensions(input));
90   TF_LITE_ENSURE_EQ(context, SizeOfDimension(seq_lengths_tensor, 0),
91                     SizeOfDimension(input, batch_dim));
92   for (int i = 0; i < NumDimensions(seq_lengths_tensor); ++i) {
93     TF_LITE_ENSURE(context, seq_lengths[i] <= SizeOfDimension(input, seq_dim));
94   }
95 
96   TfLiteTensor* output;
97   TF_LITE_ENSURE_OK(context,
98                     GetOutputSafe(context, node, kOutputTensor, &output));
99 
100   reference_ops::ReverseSequence<T, TS>(
101       seq_lengths, seq_dim, batch_dim, GetTensorShape(input),
102       GetTensorData<T>(input), GetTensorShape(output),
103       GetTensorData<T>(output));
104 
105   return kTfLiteOk;
106 }
107 
108 template <typename T>
ReverseSequenceHelper(TfLiteContext * context,TfLiteNode * node)109 TfLiteStatus ReverseSequenceHelper(TfLiteContext* context, TfLiteNode* node) {
110   const TfLiteTensor* seq_lengths_tensor;
111   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSeqLengthsTensor,
112                                           &seq_lengths_tensor));
113   switch (seq_lengths_tensor->type) {
114     case kTfLiteInt32: {
115       return ReverseSequenceImpl<T, int32_t>(context, node);
116     }
117     case kTfLiteInt64: {
118       return ReverseSequenceImpl<T, int64_t>(context, node);
119     }
120     default: {
121       context->ReportError(
122           context,
123           "Seq_lengths type '%s' is not supported by reverse_sequence.",
124           TfLiteTypeGetName(seq_lengths_tensor->type));
125       return kTfLiteError;
126     }
127   }
128   return kTfLiteOk;
129 }
130 
Eval(TfLiteContext * context,TfLiteNode * node)131 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
132   TfLiteTensor* output;
133   TF_LITE_ENSURE_OK(context,
134                     GetOutputSafe(context, node, kOutputTensor, &output));
135 
136   switch (output->type) {
137     case kTfLiteFloat32: {
138       return ReverseSequenceHelper<float>(context, node);
139     }
140     case kTfLiteUInt8: {
141       return ReverseSequenceHelper<uint8_t>(context, node);
142     }
143     case kTfLiteInt16: {
144       return ReverseSequenceHelper<int16_t>(context, node);
145     }
146     case kTfLiteInt32: {
147       return ReverseSequenceHelper<int32_t>(context, node);
148     }
149     case kTfLiteInt64: {
150       return ReverseSequenceHelper<int64_t>(context, node);
151     }
152     default: {
153       context->ReportError(context,
154                            "Type '%s' is not supported by reverse_sequence.",
155                            TfLiteTypeGetName(output->type));
156       return kTfLiteError;
157     }
158   }
159   return kTfLiteOk;
160 }  // namespace
161 
162 }  // namespace
163 }  // namespace reverse_sequence
164 
Register_REVERSE_SEQUENCE()165 TfLiteRegistration* Register_REVERSE_SEQUENCE() {
166   static TfLiteRegistration r = {nullptr, nullptr, reverse_sequence::Prepare,
167                                  reverse_sequence::Eval};
168   return &r;
169 }
170 
171 }  // namespace builtin
172 }  // namespace ops
173 }  // namespace tflite
174