/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <limits>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/experimental/kernels/gru_cell.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace experimental {
namespace unidirectional_sequence_gru {
namespace {

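// A minimal sketch of the recurrence evaluated per timestep, assuming the
// standard GRU formulation (the authoritative gate ordering and
// interpolation convention are in gru_cell.h):
//
//   [r, u] = sigmoid([x, h] * W_gate^T + b_gate)    // reset/update gates
//   c      = tanh([x, r .* h] * W_cand^T + b_cand)  // candidate activation
//   h'     = u .* h + (1 - u) .* c                  // new state
//
// This also explains the shapes checked in Prepare: gate_weight has
// 2 * n_output rows, and both weight matrices have n_input + n_output
// columns.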
void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state,
             const TfLiteTensor* gate_weight, const TfLiteTensor* gate_bias,
             const TfLiteTensor* candidate_weight,
             const TfLiteTensor* candidate_bias, TfLiteTensor* output,
             TfLiteTensor* output_state, TfLiteTensor* activation,
             TfLiteTensor* concat,
             tflite::CpuBackendContext* cpu_backend_context) {
  const int n_time = input->dims->data[0];
  const int n_batch = input->dims->data[1];
  const int n_input = input->dims->data[2];
  const int n_output = output->dims->data[2];
  const int n_batch_input = n_batch * n_input;
  const int n_batch_output = n_batch * n_output;
  const RuntimeShape input_shape({n_batch, n_input});
  const float* input_data = GetTensorData<float>(input);
  const RuntimeShape state_shape = GetTensorShape(input_state);
  const float* input_state_data = GetTensorData<float>(input_state);
  const RuntimeShape gate_weight_shape = GetTensorShape(gate_weight);
  const float* gate_weight_data = GetTensorData<float>(gate_weight);
  const RuntimeShape gate_bias_shape = GetTensorShape(gate_bias);
  const float* gate_bias_data = GetTensorData<float>(gate_bias);
  const RuntimeShape candidate_weight_shape = GetTensorShape(candidate_weight);
  const float* candidate_weight_data = GetTensorData<float>(candidate_weight);
  const RuntimeShape candidate_bias_shape = GetTensorShape(candidate_bias);
  const float* candidate_bias_data = GetTensorData<float>(candidate_bias);
  const RuntimeShape activation_shape = GetTensorShape(activation);
  const RuntimeShape output_shape = RuntimeShape({n_batch, n_output});
  float* output_data = GetTensorData<float>(output);
  float* output_state_data = GetTensorData<float>(output_state);
  float* activation_data = GetTensorData<float>(activation);
  const RuntimeShape concat_shape = GetTensorShape(concat);
  float* concat_data = GetTensorData<float>(concat);
  tflite::FullyConnectedParams fc_params;
  fc_params.float_activation_min = std::numeric_limits<float>::lowest();
  fc_params.float_activation_max = std::numeric_limits<float>::max();
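  // The bounds above span the full float range, so the cell's fully
  // connected layers apply no activation clamping. Run the cell once per
  // timestep: after each step, advance the input/output pointers by one
  // [n_batch, n_input] / [n_batch, n_output] slice and feed the freshly
  // written output state back in as the next step's input state.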
  for (int i = 0; i < n_time; ++i) {
    gru_cell::GruCell(
        input_shape, input_data, state_shape, input_state_data,
        gate_weight_shape, gate_weight_data, gate_bias_shape, gate_bias_data,
        candidate_weight_shape, candidate_weight_data, candidate_bias_shape,
        candidate_bias_data, output_shape, output_data, output_state_data,
        activation_shape, activation_data, concat_shape, concat_data,
        fc_params, cpu_backend_context);
    input_data += n_batch_input;
    output_data += n_batch_output;
    input_state_data = output_state_data;
  }
}

}  // namespace

enum InputTensor {
  // Input tensor of size [n_time, n_batch, n_input]
  kInput = 0,
  // Input state tensor of size [n_batch, n_output]
  kInputState = 1,
  // Gate weight tensor of size [2*n_output, n_input+n_output]
  kGateWeight = 2,
  // Gate bias tensor of size [2*n_output]
  kGateBias = 3,
  // Candidate weight tensor of size [n_output, n_input+n_output]
  kCandidateWeight = 4,
  // Candidate bias tensor of size [n_output]
  kCandidateBias = 5,
  kInputNum = 6
};

enum OutputTensor {
  // Output tensor of size [n_time, n_batch, n_output]
  kOutput = 0,
  // Output state tensor of size [n_batch, n_output]
  kOutputState = 1,
  kOutputNum = 2
};

enum TemporaryTensor {
  // Scratch buffer for the gate activations, of size [n_batch, 2*n_output]
  kActivation = 0,
  // Scratch buffer for the concatenated input, of size
  // [n_batch, n_input+n_output]
  kConcat = 1,
  kTemporaryNum = 2
};
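
// A worked example with hypothetical sizes: for n_time=4, n_batch=2,
// n_input=3, n_output=5, the expected tensor shapes are
//   input            [4, 2, 3]
//   input_state      [2, 5]
//   gate_weight      [10, 8]   (= [2*5, 3+5])
//   gate_bias        [10]
//   candidate_weight [5, 8]
//   candidate_bias   [5]
//   output           [4, 2, 5]
//   output_state     [2, 5]
//   activation       [2, 10]   (scratch)
//   concat           [2, 8]    (scratch)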

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* scratch_tensor_index = new int;
  context->AddTensors(context, kTemporaryNum, scratch_tensor_index);
  return scratch_tensor_index;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<int*>(buffer);
}

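// Prepare validates the declared tensor shapes against each other, resizes
// the two outputs, and sets up the two arena-allocated scratch buffers whose
// base index Init stored in node->user_data.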
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, node->inputs->size, kInputNum);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, kOutputNum);

  // input's dim = [n_time, n_batch, n_input]
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
  const int n_time = input->dims->data[0];
  const int n_batch = input->dims->data[1];
  const int n_input = input->dims->data[2];

  // input_state's dim = [n_batch, n_output]
  const TfLiteTensor* input_state;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputState, &input_state));
  TF_LITE_ENSURE_EQ(context, input_state->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_state->dims->data[0], n_batch);
  const int n_output = input_state->dims->data[1];

  // gate_weight's dim = [2 * n_output, n_input + n_output]
  const TfLiteTensor* gate_weight;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateWeight, &gate_weight));
  TF_LITE_ENSURE_EQ(context, gate_weight->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[0], 2 * n_output);
  TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[1], n_input + n_output);

  // gate_bias's dim = [2 * n_output]
  const TfLiteTensor* gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateBias, &gate_bias));
  TF_LITE_ENSURE_EQ(context, gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, gate_bias->dims->data[0], 2 * n_output);

  // candidate_weight's dim = [n_output, n_input + n_output]
  const TfLiteTensor* candidate_weight;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCandidateWeight,
                                          &candidate_weight));
  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[0], n_output);
  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[1],
                    n_input + n_output);

  // candidate_bias's dim = [n_output]
  const TfLiteTensor* candidate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kCandidateBias, &candidate_bias));
  TF_LITE_ENSURE_EQ(context, candidate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, candidate_bias->dims->data[0], n_output);

  // output's dim = [n_time, n_batch, n_output]
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutput, &output));
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
  output_size->data[0] = n_time;
  output_size->data[1] = n_batch;
  output_size->data[2] = n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));

  // output_state's dim = [n_batch, n_output]
  TfLiteTensor* output_state;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputState, &output_state));
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, output_state,
                                     TfLiteIntArrayCopy(input_state->dims)));

  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(kTemporaryNum);

  // activation's dim = [n_batch, 2 * n_output]
  node->temporaries->data[kActivation] = *scratch_tensor_index;
  TfLiteTensor* activation;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, kActivation, &activation));
  activation->type = input->type;
  activation->allocation_type = kTfLiteArenaRw;
  TfLiteIntArray* activation_size = TfLiteIntArrayCreate(2);
  activation_size->data[0] = n_batch;
  activation_size->data[1] = 2 * n_output;
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, activation, activation_size));

  // concat's dim = [n_batch, n_input + n_output]
  node->temporaries->data[kConcat] = (*scratch_tensor_index) + kConcat;
  TfLiteTensor* concat;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kConcat, &concat));
  concat->type = input->type;
  concat->allocation_type = kTfLiteArenaRw;
  TfLiteIntArray* concat_size = TfLiteIntArrayCreate(2);
  concat_size->data[0] = n_batch;
  concat_size->data[1] = n_input + n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, concat, concat_size));

  return kTfLiteOk;
}

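// Eval fetches every tensor defensively and dispatches to the float
// implementation; only float32 weights are supported.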
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
  const TfLiteTensor* input_state;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputState, &input_state));
  const TfLiteTensor* gate_weight;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateWeight, &gate_weight));
  const TfLiteTensor* gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateBias, &gate_bias));
  const TfLiteTensor* candidate_weight;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCandidateWeight,
                                          &candidate_weight));
  const TfLiteTensor* candidate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kCandidateBias, &candidate_bias));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutput, &output));
  TfLiteTensor* output_state;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputState, &output_state));
  TfLiteTensor* activation;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, kActivation, &activation));
  TfLiteTensor* concat;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kConcat, &concat));
  auto cpu_backend_context = CpuBackendContext::GetFromContext(context);

  if (gate_weight->type == kTfLiteFloat32) {
    GruImpl(input, input_state, gate_weight, gate_bias, candidate_weight,
            candidate_bias, output, output_state, activation, concat,
            cpu_backend_context);
  } else {
    context->ReportError(context,
                         "Unsupported combination of data types for GruCell");
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace unidirectional_sequence_gru

TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_GRU() {
  static TfLiteRegistration r = {
      unidirectional_sequence_gru::Init, unidirectional_sequence_gru::Free,
      unidirectional_sequence_gru::Prepare, unidirectional_sequence_gru::Eval};
  return &r;
}
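
// A sketch of client-side registration, assuming this op was exported into
// the model as a custom op named "UnidirectionalSequenceGRU" (the name is an
// assumption here; it must match whatever name the converter recorded):
//
//   tflite::MutableOpResolver resolver;
//   resolver.AddCustom(
//       "UnidirectionalSequenceGRU",
//       tflite::ops::experimental::Register_UNIDIRECTIONAL_SEQUENCE_GRU());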

}  // namespace experimental
}  // namespace ops
}  // namespace tflite