/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace bidirectional_sequence_rnn {

namespace {
struct OpData {
  int scratch_tensor_index;
  bool fw_compute_row_sums = false;
  bool bw_compute_row_sums = false;
};

}  // namespace

// LINT.IfChange

constexpr int kInputTensor = 0;
// Forward and backward cell tensors.
constexpr int kFwWeightsTensor = 1;
constexpr int kFwRecurrentWeightsTensor = 2;
constexpr int kFwBiasTensor = 3;
constexpr int kFwHiddenStateTensor = 4;
constexpr int kBwWeightsTensor = 5;
constexpr int kBwRecurrentWeightsTensor = 6;
constexpr int kBwBiasTensor = 7;
constexpr int kBwHiddenStateTensor = 8;
// Used as the auxiliary input and weights when stacking in the
// tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); used as the
// input to the backward cell when stacking in the
// tf.nn.static_bidirectional_rnn case (without cross links).
constexpr int kAuxInputTensor = 9;       // Optional.
constexpr int kFwAuxWeightsTensor = 10;  // Optional.
constexpr int kBwAuxWeightsTensor = 11;  // Optional.
// Output tensors.
constexpr int kFwOutputTensor = 0;
constexpr int kBwOutputTensor = 1;  // Only if merge_outputs is false.

// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantize_weights.cc)
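
// For reference, the shape contract enforced in Prepare() below is roughly:
//   input:                 [max_time, batch, input_size] if time_major,
//                          [batch, max_time, input_size] otherwise.
//   {fw,bw} weights:       [num_units, input_size]
//   {fw,bw} recurrent:     [num_units, num_units]
//   {fw,bw} bias:          [num_units]
//   {fw,bw} hidden state:  [batch, num_units] (variable tensors)
//   fw output:             [..., ..., fw_num_units], or
//                          [..., ..., fw_num_units + bw_num_units] when
//                          merge_outputs is true.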

// Temporary tensors.
enum TemporaryTensor {
  kInputQuantized = 0,
  kFwHiddenStateQuantized = 1,
  kBwHiddenStateQuantized = 2,
  kScalingFactors = 3,
  kAccumScratch = 4,
  kZeroPoints = 5,
  kFwRowSums = 6,
  kBwRowSums = 7,
  kAuxInputQuantized = 8,
  kNumTemporaryTensors = 9
};
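
// These temporaries are only allocated when the op runs as a hybrid op
// (float activations with uint8/int8 weights). The quantization scratch
// buffers live in the regular arena, while the row-sum caches use persistent
// arena allocation so they can be reused across invocations.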

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  context->AddTensors(context, kNumTemporaryTensors,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
      node->builtin_data);

  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 12);
  TF_LITE_ENSURE_EQ(context, node->outputs->size,
                    params->merge_outputs ? 1 : 2);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* fw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwWeightsTensor,
                                          &fw_input_weights));
  const TfLiteTensor* fw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwRecurrentWeightsTensor,
                                 &fw_recurrent_weights));
  const TfLiteTensor* fw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwBiasTensor, &fw_bias));
  const TfLiteTensor* fw_hidden_state;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwHiddenStateTensor,
                                          &fw_hidden_state));
  const TfLiteTensor* bw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwWeightsTensor,
                                          &bw_input_weights));
  const TfLiteTensor* bw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwRecurrentWeightsTensor,
                                 &bw_recurrent_weights));
  const TfLiteTensor* bw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwBiasTensor, &bw_bias));
  const TfLiteTensor* bw_hidden_state;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwHiddenStateTensor,
                                          &bw_hidden_state));

  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
  const TfLiteTensor* bw_aux_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);

  const bool aux_inputs_weights_or_none =
      ((fw_aux_input_weights != nullptr) &&
       (bw_aux_input_weights != nullptr)) ||
      ((fw_aux_input_weights == nullptr) && (bw_aux_input_weights == nullptr));
  TF_LITE_ENSURE(context, aux_inputs_weights_or_none);
  const bool has_aux_input = (fw_aux_input_weights != nullptr);

  // Check that the tensor parameters are consistent with each other and with
  // the input configuration.
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);

  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int fw_num_units = fw_input_weights->dims->data[0];
  const int bw_num_units = bw_input_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input->dims->data[2],
                    fw_input_weights->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, input->dims->data[2],
                    bw_input_weights->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, fw_input_weights->dims->data[0],
                    fw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, bw_input_weights->dims->data[0],
                    bw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, fw_recurrent_weights->dims->data[0],
                    fw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, bw_recurrent_weights->dims->data[1],
                    bw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, NumDimensions(fw_hidden_state), 2);
  TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[1], fw_num_units);
  TF_LITE_ENSURE_EQ(context, NumDimensions(bw_hidden_state), 2);
  TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[1], bw_num_units);

  if (has_aux_input) {
    // Check that aux_input has the same dimensions (except last) as the input.
    TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
    // Check that aux_input_weights has the same dimensions (except last) as
    // the input_weights.
    TF_LITE_ASSERT_EQ(fw_aux_input_weights->dims->data[0], fw_num_units);
    TF_LITE_ASSERT_EQ(bw_aux_input_weights->dims->data[0], bw_num_units);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
                      fw_aux_input_weights->dims->data[1]);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
                      bw_aux_input_weights->dims->data[1]);
  }

  if (IsHybridOp(input, fw_input_weights)) {
    OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
    op_data->fw_compute_row_sums = true;
    op_data->bw_compute_row_sums = true;
    TfLiteIntArrayFree(node->temporaries);
    if (has_aux_input) {
      node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
    } else {
      // No need to create a temporary tensor for the non-existent aux_input.
      node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors - 1);
    }

    node->temporaries->data[kInputQuantized] =
        op_data->scratch_tensor_index + kInputQuantized;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
                                                &input_quantized));
    input_quantized->type = fw_input_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, input_quantized,
                                              input_quantized_size));
    }

    node->temporaries->data[kFwHiddenStateQuantized] =
        op_data->scratch_tensor_index + kFwHiddenStateQuantized;
    TfLiteTensor* fw_hidden_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kFwHiddenStateQuantized,
                                       &fw_hidden_state_quantized));
    fw_hidden_state_quantized->type = fw_input_weights->type;
    fw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_hidden_state_quantized->dims,
                             fw_hidden_state->dims)) {
      TfLiteIntArray* fw_hidden_state_quantized_size =
          TfLiteIntArrayCopy(fw_hidden_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, fw_hidden_state_quantized,
                                         fw_hidden_state_quantized_size));
    }

    node->temporaries->data[kBwHiddenStateQuantized] =
        op_data->scratch_tensor_index + kBwHiddenStateQuantized;
    TfLiteTensor* bw_hidden_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kBwHiddenStateQuantized,
                                       &bw_hidden_state_quantized));
    bw_hidden_state_quantized->type = fw_input_weights->type;
    bw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(bw_hidden_state_quantized->dims,
                             bw_hidden_state->dims)) {
      TfLiteIntArray* bw_hidden_state_quantized_size =
          TfLiteIntArrayCopy(bw_hidden_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, bw_hidden_state_quantized,
                                         bw_hidden_state_quantized_size));
    }

    // Allocate temporary tensors to store scaling factors of quantization.
    node->temporaries->data[kScalingFactors] =
        op_data->scratch_tensor_index + kScalingFactors;
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScalingFactors,
                                                &scaling_factors));
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, scaling_factors,
                                              scaling_factors_size));
    }
    node->temporaries->data[kAccumScratch] =
        op_data->scratch_tensor_index + kAccumScratch;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
                                                &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int accum_scratch_dims[2] = {std::max(fw_num_units, bw_num_units),
                                 batch_size};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
      accum_scratch_size->data[0] = accum_scratch_dims[0];
      accum_scratch_size->data[1] = accum_scratch_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
                                                       accum_scratch_size));
    }
    node->temporaries->data[kZeroPoints] =
        op_data->scratch_tensor_index + kZeroPoints;
    TfLiteTensor* zero_points;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, /*index=*/kZeroPoints, &zero_points));
    zero_points->type = kTfLiteInt32;
    zero_points->allocation_type = kTfLiteArenaRw;
    int zero_points_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
      zero_points_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
                                                       zero_points_size));
    }
    const int num_row_sums = has_aux_input ? 3 : 2;
    node->temporaries->data[kFwRowSums] =
        op_data->scratch_tensor_index + kFwRowSums;
    TfLiteTensor* fw_row_sums;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, /*index=*/kFwRowSums, &fw_row_sums));
    fw_row_sums->type = kTfLiteInt32;
    fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int fw_row_sums_dims[2] = {num_row_sums, fw_num_units};
    if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) {
      TfLiteIntArray* fw_row_sums_size = TfLiteIntArrayCreate(2);
      fw_row_sums_size->data[0] = fw_row_sums_dims[0];
      fw_row_sums_size->data[1] = fw_row_sums_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums,
                                                       fw_row_sums_size));
    }
    node->temporaries->data[kBwRowSums] =
        op_data->scratch_tensor_index + kBwRowSums;
    TfLiteTensor* bw_row_sums;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, /*index=*/kBwRowSums, &bw_row_sums));
    bw_row_sums->type = kTfLiteInt32;
    bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int bw_row_sums_dims[2] = {num_row_sums, bw_num_units};
    if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) {
      TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2);
      bw_row_sums_size->data[0] = bw_row_sums_dims[0];
      bw_row_sums_size->data[1] = bw_row_sums_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums,
                                                       bw_row_sums_size));
    }
    if (has_aux_input) {
      node->temporaries->data[kAuxInputQuantized] =
          op_data->scratch_tensor_index + kAuxInputQuantized;
      TfLiteTensor* aux_input_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kAuxInputQuantized,
                                         &aux_input_quantized));
      aux_input_quantized->type = fw_input_weights->type;
      aux_input_quantized->allocation_type = kTfLiteArenaRw;
      if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
        TfLiteIntArray* aux_input_quantized_size =
            TfLiteIntArrayCopy(aux_input->dims);
        TF_LITE_ENSURE_OK(context,
                          context->ResizeTensor(context, aux_input_quantized,
                                                aux_input_quantized_size));
      }
    }
  }

  // Resize outputs.
  TfLiteTensor* fw_output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
  TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
  fw_output_size_array->data[0] = (time_major) ? max_time : batch_size;
  fw_output_size_array->data[1] = (time_major) ? batch_size : max_time;
  fw_output_size_array->data[2] =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  TF_LITE_ENSURE_OK(
      context,
      context->ResizeTensor(context, fw_output, fw_output_size_array));
  if (!params->merge_outputs) {
    TfLiteTensor* bw_output;
    TF_LITE_ENSURE_OK(
        context, GetOutputSafe(context, node, kBwOutputTensor, &bw_output));
    TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
    bw_output_size_array->data[0] = batch_size;
    bw_output_size_array->data[1] = max_time;
    bw_output_size_array->data[2] = bw_num_units;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output,
                                                     bw_output_size_array));
  }

  return kTfLiteOk;
}

TfLiteStatus EvalFloat(const TfLiteTensor* input, const TfLiteTensor* bw_input,
                       const TfLiteTensor* fw_input_weights,
                       const TfLiteTensor* fw_recurrent_weights,
                       const TfLiteTensor* fw_bias,
                       const TfLiteTensor* bw_input_weights,
                       const TfLiteTensor* bw_recurrent_weights,
                       const TfLiteTensor* bw_bias,
                       const TfLiteTensor* aux_input,
                       const TfLiteTensor* fw_aux_input_weights,
                       const TfLiteTensor* bw_aux_input_weights,
                       const TfLiteBidirectionalSequenceRNNParams* params,
                       TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
                       TfLiteTensor* bw_hidden_state,
                       TfLiteTensor* bw_output) {
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int input_size = input->dims->data[2];
  const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;

  const int fw_num_units = fw_input_weights->dims->data[0];
  const float* fw_bias_ptr = GetTensorData<float>(fw_bias);
  const float* fw_input_weights_ptr = GetTensorData<float>(fw_input_weights);
  const float* fw_recurrent_weights_ptr =
      GetTensorData<float>(fw_recurrent_weights);

  const int bw_num_units = bw_input_weights->dims->data[0];
  const float* bw_bias_ptr = GetTensorData<float>(bw_bias);
  const float* bw_input_weights_ptr = GetTensorData<float>(bw_input_weights);
  const float* bw_recurrent_weights_ptr =
      GetTensorData<float>(bw_recurrent_weights);

  const float* fw_aux_input_weights_ptr =
      (fw_aux_input_weights != nullptr)
          ? GetTensorData<float>(fw_aux_input_weights)
          : nullptr;
  const float* bw_aux_input_weights_ptr =
      (bw_aux_input_weights != nullptr)
          ? GetTensorData<float>(bw_aux_input_weights)
          : nullptr;

  const int fw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  const int bw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
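
  // When merge_outputs is true, both cells write into fw_output: within each
  // (time, batch) row of width fw_num_units + bw_num_units, the forward cell
  // fills the first fw_num_units entries and the backward cell fills the
  // remaining bw_num_units (hence the fw_num_units offset applied to the
  // backward output pointer below). Otherwise each cell writes to its own
  // output tensor with its own step size.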
  if (time_major) {
    // Forward cell.
    float* fw_hidden_state_ptr_batch = GetTensorData<float>(fw_hidden_state);
    for (int s = 0; s < max_time; s++) {
      const float* input_ptr_batch =
          GetTensorData<float>(input) + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? GetTensorData<float>(aux_input) + s * input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          GetTensorData<float>(fw_output) + s * fw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
          fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
          input_size, aux_input_size, fw_num_units, batch_size, fw_output_step,
          params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
    }
    // Backward cell.
    float* bw_hidden_state_ptr_batch = GetTensorData<float>(bw_hidden_state);
    for (int s = max_time - 1; s >= 0; s--) {
      const float* input_ptr_batch =
          GetTensorData<float>(bw_input) + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? GetTensorData<float>(aux_input) + s * input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          (params->merge_outputs
               ? GetTensorData<float>(fw_output) + fw_num_units
               : GetTensorData<float>(bw_output)) +
          s * bw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
          bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
          input_size, aux_input_size, bw_num_units, batch_size, bw_output_step,
          params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
    }
  } else {
    for (int b = 0; b < batch_size; b++) {
      // Forward cell.
      float* fw_hidden_state_ptr_batch =
          GetTensorData<float>(fw_hidden_state) + b * fw_num_units;
      float* fw_output_offset =
          GetTensorData<float>(fw_output) + b * fw_output_step * max_time;
      for (int s = 0; s < max_time; s++) {
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) +
                      b * aux_input_size * max_time + s * aux_input_size
                : nullptr;
        float* output_ptr_batch = fw_output_offset + s * fw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
            fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
            input_size, aux_input_size, fw_num_units, /*batch_size=*/1,
            fw_output_step, params->activation, fw_hidden_state_ptr_batch,
            output_ptr_batch);
      }
      // Backward cell.
      float* bw_hidden_state_ptr_batch =
          GetTensorData<float>(bw_hidden_state) + b * bw_num_units;
      float* bw_output_offset =
          params->merge_outputs
              ? GetTensorData<float>(fw_output) +
                    b * bw_output_step * max_time + fw_num_units
              : GetTensorData<float>(bw_output) +
                    b * bw_output_step * max_time;
      for (int s = max_time - 1; s >= 0; s--) {
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) +
                      b * aux_input_size * max_time + s * aux_input_size
                : nullptr;
        float* output_ptr_batch = bw_output_offset + s * bw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
            bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
            input_size, aux_input_size, bw_num_units, /*batch_size=*/1,
            bw_output_step, params->activation, bw_hidden_state_ptr_batch,
            output_ptr_batch);
      }
    }
  }
  return kTfLiteOk;
}

TfLiteStatus EvalHybrid(
    const TfLiteTensor* input, const TfLiteTensor* bw_input,
    const TfLiteTensor* fw_input_weights,
    const TfLiteTensor* fw_recurrent_weights, const TfLiteTensor* fw_bias,
    const TfLiteTensor* bw_input_weights,
    const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
    const TfLiteTensor* aux_input, const TfLiteTensor* aux_fw_input_weights,
    const TfLiteTensor* aux_bw_input_weights,
    const TfLiteBidirectionalSequenceRNNParams* params,
    TfLiteTensor* scaling_factors, TfLiteTensor* input_quantized,
    TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized,
    TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
    TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
    TfLiteTensor* bw_output, TfLiteTensor* zero_points,
    TfLiteTensor* accum_scratch, TfLiteTensor* fw_row_sums,
    TfLiteTensor* bw_row_sums, bool* fw_compute_row_sums,
    bool* bw_compute_row_sums) {
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int input_size = input->dims->data[2];
  const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;

  const int fw_num_units = fw_input_weights->dims->data[0];
  const float* fw_bias_ptr = GetTensorData<float>(fw_bias);
  const int8_t* fw_input_weights_ptr = GetTensorData<int8_t>(fw_input_weights);
  float fw_input_weights_scale = fw_input_weights->params.scale;
  const int8_t* fw_recurrent_weights_ptr =
      GetTensorData<int8_t>(fw_recurrent_weights);
  float fw_recurrent_weights_scale = fw_recurrent_weights->params.scale;

  const int bw_num_units = bw_input_weights->dims->data[0];
  const float* bw_bias_ptr = GetTensorData<float>(bw_bias);
  const int8_t* bw_input_weights_ptr = GetTensorData<int8_t>(bw_input_weights);
  float bw_input_weights_scale = bw_input_weights->params.scale;
  const int8_t* bw_recurrent_weights_ptr =
      GetTensorData<int8_t>(bw_recurrent_weights);
  float bw_recurrent_weights_scale = bw_recurrent_weights->params.scale;

  // Set the auxiliary pointers and scales if needed.
  const int8_t* aux_fw_input_weights_ptr = nullptr;
  float aux_fw_input_weights_scale = 0.0f;
  const int8_t* aux_bw_input_weights_ptr = nullptr;
  float aux_bw_input_weights_scale = 0.0f;
  int8_t* aux_quantized_input_ptr = nullptr;
  if (aux_input_size > 0) {
    aux_fw_input_weights_ptr = GetTensorData<int8_t>(aux_fw_input_weights);
    aux_fw_input_weights_scale = aux_fw_input_weights->params.scale;
    aux_bw_input_weights_ptr = GetTensorData<int8_t>(aux_bw_input_weights);
    aux_bw_input_weights_scale = aux_bw_input_weights->params.scale;
    aux_quantized_input_ptr = GetTensorData<int8_t>(aux_input_quantized);
  }

  // Initialize temporary storage for quantized values.
  int8_t* quantized_input_ptr = GetTensorData<int8_t>(input_quantized);
  int8_t* fw_quantized_hidden_state_ptr =
      GetTensorData<int8_t>(fw_hidden_state_quantized);
  int8_t* bw_quantized_hidden_state_ptr =
      GetTensorData<int8_t>(bw_hidden_state_quantized);
  float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
  int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
  int32_t* zero_points_ptr = nullptr;
  int32_t* fw_row_sums_ptr = nullptr;
  int32_t* bw_row_sums_ptr = nullptr;
  if (params->asymmetric_quantize_inputs) {
    zero_points_ptr = GetTensorData<int32_t>(zero_points);
    fw_row_sums_ptr = GetTensorData<int32_t>(fw_row_sums);
    bw_row_sums_ptr = GetTensorData<int32_t>(bw_row_sums);
  }
  const int fw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  const int bw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
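
  // In the hybrid path the float activations (input, aux input, hidden state)
  // are quantized on the fly into the scratch buffers above, and RnnBatchStep
  // is given the per-tensor weight scales so it can rescale the int8
  // accumulations back to float. With asymmetric_quantize_inputs, per-batch
  // zero points and the cached weight row sums are used as well.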
  if (time_major) {
    for (int t = 0; t < max_time; t++) {
      // Forward cell.
      float* fw_hidden_state_ptr_batch = GetTensorData<float>(fw_hidden_state);
      for (int s = 0; s < max_time; s++) {
        const float* input_ptr_batch =
            GetTensorData<float>(input) + s * input_size * batch_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) + s * input_size * batch_size
                : nullptr;
        float* output_ptr_batch =
            GetTensorData<float>(fw_output) + s * fw_output_step * batch_size;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
            aux_input_ptr_batch, aux_fw_input_weights_ptr,
            aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
            fw_recurrent_weights_scale, fw_bias_ptr, input_size,
            aux_input_size, fw_num_units, batch_size, fw_output_step,
            params->activation, quantized_input_ptr, aux_quantized_input_ptr,
            fw_quantized_hidden_state_ptr, scaling_factors_ptr,
            fw_hidden_state_ptr_batch, output_ptr_batch,
            params->asymmetric_quantize_inputs, zero_points_ptr,
            accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums);
      }
      // Backward cell.
      float* bw_hidden_state_ptr_batch = GetTensorData<float>(bw_hidden_state);
      for (int s = max_time - 1; s >= 0; s--) {
        const float* input_ptr_batch =
            GetTensorData<float>(bw_input) + s * input_size * batch_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) + s * input_size * batch_size
                : nullptr;
        float* output_ptr_batch =
            (params->merge_outputs
                 ? GetTensorData<float>(fw_output) + fw_num_units
                 : GetTensorData<float>(bw_output)) +
            s * bw_output_step * batch_size;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
            aux_input_ptr_batch, aux_bw_input_weights_ptr,
            aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
            bw_recurrent_weights_scale, bw_bias_ptr, input_size,
            aux_input_size, bw_num_units, batch_size, bw_output_step,
            params->activation, quantized_input_ptr, aux_quantized_input_ptr,
            bw_quantized_hidden_state_ptr, scaling_factors_ptr,
            bw_hidden_state_ptr_batch, output_ptr_batch,
            params->asymmetric_quantize_inputs, zero_points_ptr,
            accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums);
      }
    }
  } else {
    for (int b = 0; b < batch_size; b++) {
      // Forward cell.
      float* fw_hidden_state_ptr_batch =
          GetTensorData<float>(fw_hidden_state) + b * fw_num_units;
      float* fw_output_offset =
          GetTensorData<float>(fw_output) + b * fw_output_step * max_time;
      for (int s = 0; s < max_time; s++) {
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) +
                      b * input_size * max_time + s * input_size
                : nullptr;
        float* output_ptr_batch = fw_output_offset + s * fw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
            aux_input_ptr_batch, aux_fw_input_weights_ptr,
            aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
            fw_recurrent_weights_scale, fw_bias_ptr, input_size,
            aux_input_size, fw_num_units, /*batch_size=*/1, fw_output_step,
            params->activation, quantized_input_ptr, aux_quantized_input_ptr,
            fw_quantized_hidden_state_ptr, scaling_factors_ptr,
            fw_hidden_state_ptr_batch, output_ptr_batch,
            params->asymmetric_quantize_inputs, zero_points_ptr,
            accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums);
      }
      // Backward cell.
      float* bw_hidden_state_ptr_batch =
          GetTensorData<float>(bw_hidden_state) + b * bw_num_units;
      float* bw_output_offset =
          params->merge_outputs
              ? GetTensorData<float>(fw_output) +
                    b * bw_output_step * max_time + fw_num_units
              : GetTensorData<float>(bw_output) +
                    b * bw_output_step * max_time;
      for (int s = max_time - 1; s >= 0; s--) {
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) +
                      b * input_size * max_time + s * input_size
                : nullptr;
        float* output_ptr_batch = bw_output_offset + s * bw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
            aux_input_ptr_batch, aux_bw_input_weights_ptr,
            aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
            bw_recurrent_weights_scale, bw_bias_ptr, input_size,
            aux_input_size, bw_num_units, /*batch_size=*/1, bw_output_step,
            params->activation, quantized_input_ptr, aux_quantized_input_ptr,
            bw_quantized_hidden_state_ptr, scaling_factors_ptr,
            bw_hidden_state_ptr_batch, output_ptr_batch,
            params->asymmetric_quantize_inputs, zero_points_ptr,
            accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums);
      }
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
      node->builtin_data);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* fw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwWeightsTensor,
                                          &fw_input_weights));
  const TfLiteTensor* fw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwRecurrentWeightsTensor,
                                 &fw_recurrent_weights));
  const TfLiteTensor* fw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwBiasTensor, &fw_bias));
  const TfLiteTensor* bw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwWeightsTensor,
                                          &bw_input_weights));
  const TfLiteTensor* bw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwRecurrentWeightsTensor,
                                 &bw_recurrent_weights));
  const TfLiteTensor* bw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwBiasTensor, &bw_bias));

  // Get auxiliary inputs.
  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
  const TfLiteTensor* bw_aux_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);

  TfLiteTensor* fw_hidden_state =
      GetVariableInput(context, node, kFwHiddenStateTensor);
  TFLITE_DCHECK(fw_hidden_state != nullptr);
  TfLiteTensor* bw_hidden_state =
      GetVariableInput(context, node, kBwHiddenStateTensor);
  TFLITE_DCHECK(bw_hidden_state != nullptr);

  TfLiteTensor* fw_output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
  TfLiteTensor* bw_output = params->merge_outputs
                                ? nullptr
                                : GetOutput(context, node, kBwOutputTensor);

  const bool has_previous_bw_output = (aux_input != nullptr);
  const bool use_aux_input = (fw_aux_input_weights != nullptr);

  // We want to cover the following cases:
  //
  // If not stacking (not connected after other bidi lstms):
  //   both fw & bw will just use `input`; aux_input will be null.
  //
  // If stacking with cross_links, TensorFlow equivalent
  // (tf.contrib.rnn.stack_bidirectional_rnn):
  //   both fw & bw will use `input`, but aux_input will be non-null.
  //   Note that in this case it does not matter whether the op is connected
  //   after another bidi lstm or not.
  //
  // If stacking without cross_links, but connected after other bidi lstms,
  // TensorFlow equivalent (tf.nn.static_bidirectional_rnn):
  //   fw will use `input`, bw will use aux_input, and the `real aux_input`
  //   will be null.

  const bool non_stacking_mode = !use_aux_input && has_previous_bw_output;
  const TfLiteTensor* bw_input = non_stacking_mode ? aux_input : input;
  const TfLiteTensor* real_aux_input = non_stacking_mode ? nullptr : aux_input;
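  // Wiring for the three cases above:
  //   not stacking:             bw_input = input,     real_aux_input = null
  //   stacking, cross links:    bw_input = input,     real_aux_input = aux_input
  //   stacking, no cross links: bw_input = aux_input, real_aux_input = null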

  switch (fw_input_weights->type) {
    case kTfLiteFloat32:
      return EvalFloat(input, bw_input, fw_input_weights, fw_recurrent_weights,
                       fw_bias, bw_input_weights, bw_recurrent_weights,
                       bw_bias, real_aux_input, fw_aux_input_weights,
                       bw_aux_input_weights, params, fw_hidden_state,
                       fw_output, bw_hidden_state, bw_output);
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      TfLiteTensor* input_quantized;
      TF_LITE_ENSURE_OK(
          context,
          GetTemporarySafe(context, node, kInputQuantized, &input_quantized));
      TfLiteTensor* fw_hidden_state_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node,
                                         kFwHiddenStateQuantized,
                                         &fw_hidden_state_quantized));
      TfLiteTensor* bw_hidden_state_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node,
                                         kBwHiddenStateQuantized,
                                         &bw_hidden_state_quantized));
      TfLiteTensor* scaling_factors;
      TF_LITE_ENSURE_OK(
          context,
          GetTemporarySafe(context, node, kScalingFactors, &scaling_factors));
      TfLiteTensor* zero_points;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kZeroPoints, &zero_points));
      TfLiteTensor* accum_scratch;
      TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
                                                  &accum_scratch));
      TfLiteTensor* fw_row_sums;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
      TfLiteTensor* bw_row_sums;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
      TfLiteTensor* aux_input_quantized =
          use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
                        : nullptr;
      auto* op_data = reinterpret_cast<OpData*>(node->user_data);
      return EvalHybrid(
          input, bw_input, fw_input_weights, fw_recurrent_weights, fw_bias,
          bw_input_weights, bw_recurrent_weights, bw_bias, real_aux_input,
          fw_aux_input_weights, bw_aux_input_weights, params, scaling_factors,
          input_quantized, aux_input_quantized, fw_hidden_state_quantized,
          fw_hidden_state, fw_output, bw_hidden_state_quantized,
          bw_hidden_state, bw_output, zero_points, accum_scratch, fw_row_sums,
          bw_row_sums, &op_data->fw_compute_row_sums,
          &op_data->bw_compute_row_sums);
    }
    default:
      context->ReportError(context, "Type not currently supported.");
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace bidirectional_sequence_rnn

TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN() {
  static TfLiteRegistration r = {
      bidirectional_sequence_rnn::Init, bidirectional_sequence_rnn::Free,
      bidirectional_sequence_rnn::Prepare, bidirectional_sequence_rnn::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite