1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <stdint.h>
17
18 #include "tensorflow/lite/c/builtin_op_data.h"
19 #include "tensorflow/lite/c/common.h"
20 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21 #include "tensorflow/lite/kernels/internal/tensor.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24
25 namespace tflite {
26 namespace ops {
27 namespace builtin {
28 namespace reverse_sequence {
29 namespace {
30
// Tensor indices used with GetInputSafe / GetOutputSafe for this op.
constexpr int kInputTensor{0};       // Input 0: the data tensor to reverse.
constexpr int kSeqLengthsTensor{1};  // Input 1: one length per batch entry.
constexpr int kOutputTensor{0};      // Output 0: same shape/type as input 0.
34
Prepare(TfLiteContext * context,TfLiteNode * node)35 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
36 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
37 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
38
39 const TfLiteTensor* input;
40 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
41 const TfLiteTensor* seq_lengths;
42 TF_LITE_ENSURE_OK(
43 context, GetInputSafe(context, node, kSeqLengthsTensor, &seq_lengths));
44 TF_LITE_ENSURE_EQ(context, NumDimensions(seq_lengths), 1);
45
46 if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
47 input->type != kTfLiteUInt8 && input->type != kTfLiteInt16 &&
48 input->type != kTfLiteInt64) {
49 context->ReportError(context,
50 "Type '%s' is not supported by reverse_sequence.",
51 TfLiteTypeGetName(input->type));
52 return kTfLiteError;
53 }
54
55 if (seq_lengths->type != kTfLiteInt32 && seq_lengths->type != kTfLiteInt64) {
56 context->ReportError(
57 context, "Seq_lengths type '%s' is not supported by reverse_sequence.",
58 TfLiteTypeGetName(seq_lengths->type));
59 return kTfLiteError;
60 }
61
62 TfLiteTensor* output;
63 TF_LITE_ENSURE_OK(context,
64 GetOutputSafe(context, node, kOutputTensor, &output));
65 TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
66 TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
67
68 return context->ResizeTensor(context, output, output_shape);
69 }
70
71 template <typename T, typename TS>
ReverseSequenceImpl(TfLiteContext * context,TfLiteNode * node)72 TfLiteStatus ReverseSequenceImpl(TfLiteContext* context, TfLiteNode* node) {
73 const TfLiteTensor* input;
74 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
75 const TfLiteTensor* seq_lengths_tensor;
76 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSeqLengthsTensor,
77 &seq_lengths_tensor));
78 const TS* seq_lengths = GetTensorData<TS>(seq_lengths_tensor);
79
80 auto* params =
81 reinterpret_cast<TfLiteReverseSequenceParams*>(node->builtin_data);
82 int seq_dim = params->seq_dim;
83 int batch_dim = params->batch_dim;
84
85 TF_LITE_ENSURE(context, seq_dim >= 0);
86 TF_LITE_ENSURE(context, batch_dim >= 0);
87 TF_LITE_ENSURE(context, seq_dim != batch_dim);
88 TF_LITE_ENSURE(context, seq_dim < NumDimensions(input));
89 TF_LITE_ENSURE(context, batch_dim < NumDimensions(input));
90 TF_LITE_ENSURE_EQ(context, SizeOfDimension(seq_lengths_tensor, 0),
91 SizeOfDimension(input, batch_dim));
92 for (int i = 0; i < NumDimensions(seq_lengths_tensor); ++i) {
93 TF_LITE_ENSURE(context, seq_lengths[i] <= SizeOfDimension(input, seq_dim));
94 }
95
96 TfLiteTensor* output;
97 TF_LITE_ENSURE_OK(context,
98 GetOutputSafe(context, node, kOutputTensor, &output));
99
100 reference_ops::ReverseSequence<T, TS>(
101 seq_lengths, seq_dim, batch_dim, GetTensorShape(input),
102 GetTensorData<T>(input), GetTensorShape(output),
103 GetTensorData<T>(output));
104
105 return kTfLiteOk;
106 }
107
108 template <typename T>
ReverseSequenceHelper(TfLiteContext * context,TfLiteNode * node)109 TfLiteStatus ReverseSequenceHelper(TfLiteContext* context, TfLiteNode* node) {
110 const TfLiteTensor* seq_lengths_tensor;
111 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSeqLengthsTensor,
112 &seq_lengths_tensor));
113 switch (seq_lengths_tensor->type) {
114 case kTfLiteInt32: {
115 return ReverseSequenceImpl<T, int32_t>(context, node);
116 }
117 case kTfLiteInt64: {
118 return ReverseSequenceImpl<T, int64_t>(context, node);
119 }
120 default: {
121 context->ReportError(
122 context,
123 "Seq_lengths type '%s' is not supported by reverse_sequence.",
124 TfLiteTypeGetName(seq_lengths_tensor->type));
125 return kTfLiteError;
126 }
127 }
128 return kTfLiteOk;
129 }
130
Eval(TfLiteContext * context,TfLiteNode * node)131 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
132 TfLiteTensor* output;
133 TF_LITE_ENSURE_OK(context,
134 GetOutputSafe(context, node, kOutputTensor, &output));
135
136 switch (output->type) {
137 case kTfLiteFloat32: {
138 return ReverseSequenceHelper<float>(context, node);
139 }
140 case kTfLiteUInt8: {
141 return ReverseSequenceHelper<uint8_t>(context, node);
142 }
143 case kTfLiteInt16: {
144 return ReverseSequenceHelper<int16_t>(context, node);
145 }
146 case kTfLiteInt32: {
147 return ReverseSequenceHelper<int32_t>(context, node);
148 }
149 case kTfLiteInt64: {
150 return ReverseSequenceHelper<int64_t>(context, node);
151 }
152 default: {
153 context->ReportError(context,
154 "Type '%s' is not supported by reverse_sequence.",
155 TfLiteTypeGetName(output->type));
156 return kTfLiteError;
157 }
158 }
159 return kTfLiteOk;
160 } // namespace
161
162 } // namespace
163 } // namespace reverse_sequence
164
Register_REVERSE_SEQUENCE()165 TfLiteRegistration* Register_REVERSE_SEQUENCE() {
166 static TfLiteRegistration r = {nullptr, nullptr, reverse_sequence::Prepare,
167 reverse_sequence::Eval};
168 return &r;
169 }
170
171 } // namespace builtin
172 } // namespace ops
173 } // namespace tflite
174