/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace bidirectional_sequence_rnn {

namespace {

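// Persistent per-node data: the index of the first temporary tensor reserved
// in Init(), plus flags telling the hybrid path whether the cached weight row
// sums still need to be (re)computed.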
struct OpData {
  int scratch_tensor_index;
  bool fw_compute_row_sums = false;
  bool bw_compute_row_sums = false;
};

}  // namespace

// LINT.IfChange

constexpr int kInputTensor = 0;
// Forward and backward cell tensors.
constexpr int kFwWeightsTensor = 1;
constexpr int kFwRecurrentWeightsTensor = 2;
constexpr int kFwBiasTensor = 3;
constexpr int kFwHiddenStateTensor = 4;
constexpr int kBwWeightsTensor = 5;
constexpr int kBwRecurrentWeightsTensor = 6;
constexpr int kBwBiasTensor = 7;
constexpr int kBwHiddenStateTensor = 8;
// Used as auxiliary input and weights when stacking for
// tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); used as
// input to the backward cell when stacking for tf.nn.static_bidirectional_rnn
// case (without cross links).
constexpr int kAuxInputTensor = 9;       // Optional.
constexpr int kFwAuxWeightsTensor = 10;  // Optional.
constexpr int kBwAuxWeightsTensor = 11;  // Optional.
// Output tensors.
constexpr int kFwOutputTensor = 0;
constexpr int kBwOutputTensor = 1;  // Only if merge_outputs is false.

// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantize_weights.cc)

// Temporary tensors.
enum TemporaryTensor {
  kInputQuantized = 0,
  kFwHiddenStateQuantized = 1,
  kBwHiddenStateQuantized = 2,
  kScalingFactors = 3,
  kAccumScratch = 4,
  kZeroPoints = 5,
  kFwRowSums = 6,
  kBwRowSums = 7,
  kAuxInputQuantized = 8,
  kNumTemporaryTensors = 9
};
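// These temporaries back the hybrid path (8-bit weights with float
// activations) and are only allocated in Prepare() when IsHybridOp() is true.
// Init() below reserves kNumTemporaryTensors consecutive tensor slots and
// records the index of the first one in OpData::scratch_tensor_index.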

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  context->AddTensors(context, kNumTemporaryTensors,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
      node->builtin_data);

  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 12);
  TF_LITE_ENSURE_EQ(context, node->outputs->size,
                    params->merge_outputs ? 1 : 2);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* fw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwWeightsTensor,
                                          &fw_input_weights));
  const TfLiteTensor* fw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwRecurrentWeightsTensor,
                                 &fw_recurrent_weights));
  const TfLiteTensor* fw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwBiasTensor, &fw_bias));
  const TfLiteTensor* fw_hidden_state;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwHiddenStateTensor,
                                          &fw_hidden_state));
  const TfLiteTensor* bw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwWeightsTensor,
                                          &bw_input_weights));
  const TfLiteTensor* bw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwRecurrentWeightsTensor,
                                 &bw_recurrent_weights));
  const TfLiteTensor* bw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwBiasTensor, &bw_bias));
  const TfLiteTensor* bw_hidden_state;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwHiddenStateTensor,
                                          &bw_hidden_state));

  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
  const TfLiteTensor* bw_aux_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);

  const bool aux_inputs_weights_or_none =
      ((fw_aux_input_weights != nullptr) &&
       (bw_aux_input_weights != nullptr)) ||
      ((fw_aux_input_weights == nullptr) && (bw_aux_input_weights == nullptr));
  TF_LITE_ENSURE(context, aux_inputs_weights_or_none);
  const bool has_aux_input = (fw_aux_input_weights != nullptr);

  // Check that all tensor parameters are consistent with each other and with
  // the input configuration.
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);

  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int fw_num_units = fw_input_weights->dims->data[0];
  const int bw_num_units = bw_input_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input->dims->data[2],
                    fw_input_weights->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, input->dims->data[2],
                    bw_input_weights->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, fw_input_weights->dims->data[0],
                    fw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, bw_input_weights->dims->data[0],
                    bw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, fw_recurrent_weights->dims->data[0],
                    fw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, bw_recurrent_weights->dims->data[1],
                    bw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, NumDimensions(fw_hidden_state), 2);
  TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[1], fw_num_units);
  TF_LITE_ENSURE_EQ(context, NumDimensions(bw_hidden_state), 2);
  TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[1], bw_num_units);

  if (has_aux_input) {
    // Check that aux_input has the same dimensions (except last) as the input.
    TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
    // Check that aux_input_weights has the same dimensions (except last) as
    // the input_weights.
    TF_LITE_ASSERT_EQ(fw_aux_input_weights->dims->data[0], fw_num_units);
    TF_LITE_ASSERT_EQ(bw_aux_input_weights->dims->data[0], bw_num_units);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
                      fw_aux_input_weights->dims->data[1]);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
                      bw_aux_input_weights->dims->data[1]);
  }

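  // For hybrid execution (float input with 8-bit weights), allocate the
  // temporary tensors used to quantize activations and hidden states at run
  // time.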
  if (IsHybridOp(input, fw_input_weights)) {
    OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
    op_data->fw_compute_row_sums = true;
    op_data->bw_compute_row_sums = true;
    TfLiteIntArrayFree(node->temporaries);
    if (has_aux_input) {
      node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
    } else {
      // No need to create a temporary tensor for the non-existent aux_input.
      node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors - 1);
    }

    node->temporaries->data[kInputQuantized] =
        op_data->scratch_tensor_index + kInputQuantized;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
                                                &input_quantized));
    input_quantized->type = fw_input_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }

    node->temporaries->data[kFwHiddenStateQuantized] =
        op_data->scratch_tensor_index + kFwHiddenStateQuantized;
    TfLiteTensor* fw_hidden_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kFwHiddenStateQuantized,
                                       &fw_hidden_state_quantized));
    fw_hidden_state_quantized->type = fw_input_weights->type;
    fw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_hidden_state_quantized->dims,
                             fw_hidden_state->dims)) {
      TfLiteIntArray* fw_hidden_state_quantized_size =
          TfLiteIntArrayCopy(fw_hidden_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, fw_hidden_state_quantized,
                                         fw_hidden_state_quantized_size));
    }

    node->temporaries->data[kBwHiddenStateQuantized] =
        op_data->scratch_tensor_index + kBwHiddenStateQuantized;
    TfLiteTensor* bw_hidden_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kBwHiddenStateQuantized,
                                       &bw_hidden_state_quantized));
    bw_hidden_state_quantized->type = fw_input_weights->type;
    bw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(bw_hidden_state_quantized->dims,
                             bw_hidden_state->dims)) {
      TfLiteIntArray* bw_hidden_state_quantized_size =
          TfLiteIntArrayCopy(bw_hidden_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, bw_hidden_state_quantized,
                                         bw_hidden_state_quantized_size));
    }

    // Allocate temporary tensors to store scaling factors of quantization.
    node->temporaries->data[kScalingFactors] =
        op_data->scratch_tensor_index + kScalingFactors;
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScalingFactors,
                                                &scaling_factors));
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }
    node->temporaries->data[kAccumScratch] =
        op_data->scratch_tensor_index + kAccumScratch;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
                                                &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int accum_scratch_dims[2] = {std::max(fw_num_units, bw_num_units),
                                 batch_size};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
      accum_scratch_size->data[0] = accum_scratch_dims[0];
      accum_scratch_size->data[1] = accum_scratch_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
                                                       accum_scratch_size));
    }
    node->temporaries->data[kZeroPoints] =
        op_data->scratch_tensor_index + kZeroPoints;
    TfLiteTensor* zero_points;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, /*index=*/kZeroPoints, &zero_points));
    zero_points->type = kTfLiteInt32;
    zero_points->allocation_type = kTfLiteArenaRw;
    int zero_points_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
      zero_points_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
                                                       zero_points_size));
    }
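    // Row sums of the weight matrices are cached for asymmetric quantization:
    // two sets (input and recurrent weights) per direction, plus one more for
    // the aux weights when an aux input is present.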
    const int num_row_sums = has_aux_input ? 3 : 2;
    node->temporaries->data[kFwRowSums] =
        op_data->scratch_tensor_index + kFwRowSums;
    TfLiteTensor* fw_row_sums;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, /*index=*/kFwRowSums, &fw_row_sums));
    fw_row_sums->type = kTfLiteInt32;
    fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int fw_row_sums_dims[2] = {num_row_sums, fw_num_units};
    if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) {
      TfLiteIntArray* fw_row_sums_size = TfLiteIntArrayCreate(2);
      fw_row_sums_size->data[0] = fw_row_sums_dims[0];
      fw_row_sums_size->data[1] = fw_row_sums_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums,
                                                       fw_row_sums_size));
    }
    node->temporaries->data[kBwRowSums] =
        op_data->scratch_tensor_index + kBwRowSums;
    TfLiteTensor* bw_row_sums;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, /*index=*/kBwRowSums, &bw_row_sums));
    bw_row_sums->type = kTfLiteInt32;
    bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int bw_row_sums_dims[2] = {num_row_sums, bw_num_units};
    if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) {
      TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2);
      bw_row_sums_size->data[0] = bw_row_sums_dims[0];
      bw_row_sums_size->data[1] = bw_row_sums_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums,
                                                       bw_row_sums_size));
    }
    if (has_aux_input) {
      node->temporaries->data[kAuxInputQuantized] =
          op_data->scratch_tensor_index + kAuxInputQuantized;
      TfLiteTensor* aux_input_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kAuxInputQuantized,
                                         &aux_input_quantized));
      aux_input_quantized->type = fw_input_weights->type;
      aux_input_quantized->allocation_type = kTfLiteArenaRw;
      if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
        TfLiteIntArray* aux_input_quantized_size =
            TfLiteIntArrayCopy(aux_input->dims);
        TF_LITE_ENSURE_OK(context,
                          context->ResizeTensor(context, aux_input_quantized,
                                                aux_input_quantized_size));
      }
    }
  }

  // Resize outputs.
  TfLiteTensor* fw_output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
  TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
  fw_output_size_array->data[0] = (time_major) ? max_time : batch_size;
  fw_output_size_array->data[1] = (time_major) ? batch_size : max_time;
  fw_output_size_array->data[2] =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, fw_output, fw_output_size_array));
  if (!params->merge_outputs) {
    TfLiteTensor* bw_output;
    TF_LITE_ENSURE_OK(
        context, GetOutputSafe(context, node, kBwOutputTensor, &bw_output));
    TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
    bw_output_size_array->data[0] = batch_size;
    bw_output_size_array->data[1] = max_time;
    bw_output_size_array->data[2] = bw_num_units;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output,
                                                     bw_output_size_array));
  }

  return kTfLiteOk;
}

TfLiteStatus EvalFloat(const TfLiteTensor* input, const TfLiteTensor* bw_input,
                       const TfLiteTensor* fw_input_weights,
                       const TfLiteTensor* fw_recurrent_weights,
                       const TfLiteTensor* fw_bias,
                       const TfLiteTensor* bw_input_weights,
                       const TfLiteTensor* bw_recurrent_weights,
                       const TfLiteTensor* bw_bias,
                       const TfLiteTensor* aux_input,
                       const TfLiteTensor* fw_aux_input_weights,
                       const TfLiteTensor* bw_aux_input_weights,
                       const TfLiteBidirectionalSequenceRNNParams* params,
                       TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
                       TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int input_size = input->dims->data[2];
  const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;

  const int fw_num_units = fw_input_weights->dims->data[0];
  const float* fw_bias_ptr = GetTensorData<float>(fw_bias);
  const float* fw_input_weights_ptr = GetTensorData<float>(fw_input_weights);
  const float* fw_recurrent_weights_ptr =
      GetTensorData<float>(fw_recurrent_weights);

  const int bw_num_units = bw_input_weights->dims->data[0];
  const float* bw_bias_ptr = GetTensorData<float>(bw_bias);
  const float* bw_input_weights_ptr = GetTensorData<float>(bw_input_weights);
  const float* bw_recurrent_weights_ptr =
      GetTensorData<float>(bw_recurrent_weights);

  const float* fw_aux_input_weights_ptr =
      (fw_aux_input_weights != nullptr)
          ? GetTensorData<float>(fw_aux_input_weights)
          : nullptr;
  const float* bw_aux_input_weights_ptr =
      (bw_aux_input_weights != nullptr)
          ? GetTensorData<float>(bw_aux_input_weights)
          : nullptr;

  const int fw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  const int bw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
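  // When merge_outputs is true, both cells write into fw_output: each time
  // step is a row of width fw_num_units + bw_num_units, with the backward
  // cell's block starting at column offset fw_num_units.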
  if (time_major) {
    // Forward cell.
    float* fw_hidden_state_ptr_batch = GetTensorData<float>(fw_hidden_state);
    for (int s = 0; s < max_time; s++) {
      const float* input_ptr_batch =
          GetTensorData<float>(input) + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? GetTensorData<float>(aux_input) + s * input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          GetTensorData<float>(fw_output) + s * fw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
          fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
          input_size, aux_input_size, fw_num_units, batch_size, fw_output_step,
          params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
    }
    // Backward cell.
    float* bw_hidden_state_ptr_batch = GetTensorData<float>(bw_hidden_state);
    for (int s = max_time - 1; s >= 0; s--) {
      const float* input_ptr_batch =
          GetTensorData<float>(bw_input) + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? GetTensorData<float>(aux_input) + s * input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          (params->merge_outputs
               ? GetTensorData<float>(fw_output) + fw_num_units
               : GetTensorData<float>(bw_output)) +
          s * bw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
          bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
          input_size, aux_input_size, bw_num_units, batch_size, bw_output_step,
          params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
    }
  } else {
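    // Batch-major layout: process each sequence in the batch independently,
    // stepping through time with an effective batch size of 1.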
    for (int b = 0; b < batch_size; b++) {
      // Forward cell.
      float* fw_hidden_state_ptr_batch =
          GetTensorData<float>(fw_hidden_state) + b * fw_num_units;
      float* fw_output_offset =
          GetTensorData<float>(fw_output) + b * fw_output_step * max_time;
      for (int s = 0; s < max_time; s++) {
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) +
                      b * aux_input_size * max_time + s * aux_input_size
                : nullptr;
        float* output_ptr_batch = fw_output_offset + s * fw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
            fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
            input_size, aux_input_size, fw_num_units, /*batch_size=*/1,
            fw_output_step, params->activation, fw_hidden_state_ptr_batch,
            output_ptr_batch);
      }
      // Backward cell.
      float* bw_hidden_state_ptr_batch =
          GetTensorData<float>(bw_hidden_state) + b * bw_num_units;
      float* bw_output_offset =
          params->merge_outputs
              ? GetTensorData<float>(fw_output) +
                    b * bw_output_step * max_time + fw_num_units
              : GetTensorData<float>(bw_output) + b * bw_output_step * max_time;
      for (int s = max_time - 1; s >= 0; s--) {
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) +
                      b * aux_input_size * max_time + s * aux_input_size
                : nullptr;
        float* output_ptr_batch = bw_output_offset + s * bw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
            bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
            input_size, aux_input_size, bw_num_units, /*batch_size=*/1,
            bw_output_step, params->activation, bw_hidden_state_ptr_batch,
            output_ptr_batch);
      }
    }
  }
  return kTfLiteOk;
}

TfLiteStatus EvalHybrid(
    const TfLiteTensor* input, const TfLiteTensor* bw_input,
    const TfLiteTensor* fw_input_weights,
    const TfLiteTensor* fw_recurrent_weights, const TfLiteTensor* fw_bias,
    const TfLiteTensor* bw_input_weights,
    const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
    const TfLiteTensor* aux_input, const TfLiteTensor* aux_fw_input_weights,
    const TfLiteTensor* aux_bw_input_weights,
    const TfLiteBidirectionalSequenceRNNParams* params,
    TfLiteTensor* scaling_factors, TfLiteTensor* input_quantized,
    TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized,
    TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
    TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
    TfLiteTensor* bw_output, TfLiteTensor* zero_points,
    TfLiteTensor* accum_scratch, TfLiteTensor* fw_row_sums,
    TfLiteTensor* bw_row_sums, bool* fw_compute_row_sums,
    bool* bw_compute_row_sums) {
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int input_size = input->dims->data[2];
  const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;

  const int fw_num_units = fw_input_weights->dims->data[0];
  const float* fw_bias_ptr = GetTensorData<float>(fw_bias);
  const int8_t* fw_input_weights_ptr = GetTensorData<int8_t>(fw_input_weights);
  float fw_input_weights_scale = fw_input_weights->params.scale;
  const int8_t* fw_recurrent_weights_ptr =
      GetTensorData<int8_t>(fw_recurrent_weights);
  float fw_recurrent_weights_scale = fw_recurrent_weights->params.scale;

  const int bw_num_units = bw_input_weights->dims->data[0];
  const float* bw_bias_ptr = GetTensorData<float>(bw_bias);
  const int8_t* bw_input_weights_ptr = GetTensorData<int8_t>(bw_input_weights);
  float bw_input_weights_scale = bw_input_weights->params.scale;
  const int8_t* bw_recurrent_weights_ptr =
      GetTensorData<int8_t>(bw_recurrent_weights);
  float bw_recurrent_weights_scale = bw_recurrent_weights->params.scale;

  // Set the auxiliary pointers and scales if needed.
  const int8_t* aux_fw_input_weights_ptr = nullptr;
  float aux_fw_input_weights_scale = 0.0f;
  const int8_t* aux_bw_input_weights_ptr = nullptr;
  float aux_bw_input_weights_scale = 0.0f;
  int8_t* aux_quantized_input_ptr = nullptr;
  if (aux_input_size > 0) {
    aux_fw_input_weights_ptr = GetTensorData<int8_t>(aux_fw_input_weights);
    aux_fw_input_weights_scale = aux_fw_input_weights->params.scale;
    aux_bw_input_weights_ptr = GetTensorData<int8_t>(aux_bw_input_weights);
    aux_bw_input_weights_scale = aux_bw_input_weights->params.scale;
    aux_quantized_input_ptr = GetTensorData<int8_t>(aux_input_quantized);
  }

  // Initialize temporary storage for quantized values.
  int8_t* quantized_input_ptr = GetTensorData<int8_t>(input_quantized);
  int8_t* fw_quantized_hidden_state_ptr =
      GetTensorData<int8_t>(fw_hidden_state_quantized);
  int8_t* bw_quantized_hidden_state_ptr =
      GetTensorData<int8_t>(bw_hidden_state_quantized);
  float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
  int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
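  // Zero points and cached weight row sums are only needed when the inputs
  // are asymmetrically quantized.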
  int32_t* zero_points_ptr = nullptr;
  int32_t* fw_row_sums_ptr = nullptr;
  int32_t* bw_row_sums_ptr = nullptr;
  if (params->asymmetric_quantize_inputs) {
    zero_points_ptr = GetTensorData<int32_t>(zero_points);
    fw_row_sums_ptr = GetTensorData<int32_t>(fw_row_sums);
    bw_row_sums_ptr = GetTensorData<int32_t>(bw_row_sums);
  }
  const int fw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  const int bw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
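  // Output layout follows EvalFloat: with merge_outputs the two cells share
  // fw_output, and the backward block is offset by fw_num_units per step.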

  if (time_major) {
    for (int t = 0; t < max_time; t++) {
      // Forward cell.
      float* fw_hidden_state_ptr_batch = GetTensorData<float>(fw_hidden_state);
      for (int s = 0; s < max_time; s++) {
        const float* input_ptr_batch =
            GetTensorData<float>(input) + s * input_size * batch_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) + s * input_size * batch_size
                : nullptr;
        float* output_ptr_batch =
            GetTensorData<float>(fw_output) + s * fw_output_step * batch_size;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
            aux_input_ptr_batch, aux_fw_input_weights_ptr,
            aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
            fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
            fw_num_units, batch_size, fw_output_step, params->activation,
            quantized_input_ptr, aux_quantized_input_ptr,
            fw_quantized_hidden_state_ptr, scaling_factors_ptr,
            fw_hidden_state_ptr_batch, output_ptr_batch,
            params->asymmetric_quantize_inputs, zero_points_ptr,
            accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums);
      }
      // Backward cell.
      float* bw_hidden_state_ptr_batch = GetTensorData<float>(bw_hidden_state);
      for (int s = max_time - 1; s >= 0; s--) {
        const float* input_ptr_batch =
            GetTensorData<float>(bw_input) + s * input_size * batch_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) + s * input_size * batch_size
                : nullptr;
        float* output_ptr_batch =
            (params->merge_outputs
                 ? GetTensorData<float>(fw_output) + fw_num_units
                 : GetTensorData<float>(bw_output)) +
            s * bw_output_step * batch_size;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
            aux_input_ptr_batch, aux_bw_input_weights_ptr,
            aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
            bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
            bw_num_units, batch_size, bw_output_step, params->activation,
            quantized_input_ptr, aux_quantized_input_ptr,
            bw_quantized_hidden_state_ptr, scaling_factors_ptr,
            bw_hidden_state_ptr_batch, output_ptr_batch,
            params->asymmetric_quantize_inputs, zero_points_ptr,
            accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums);
      }
    }
  } else {
    for (int b = 0; b < batch_size; b++) {
      // Forward cell.
      float* fw_hidden_state_ptr_batch =
          GetTensorData<float>(fw_hidden_state) + b * fw_num_units;
      float* fw_output_offset =
          GetTensorData<float>(fw_output) + b * fw_output_step * max_time;
      for (int s = 0; s < max_time; s++) {
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) + b * input_size * max_time +
                      s * input_size
                : nullptr;
        float* output_ptr_batch = fw_output_offset + s * fw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
            aux_input_ptr_batch, aux_fw_input_weights_ptr,
            aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
            fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
            fw_num_units, /*batch_size=*/1, fw_output_step, params->activation,
            quantized_input_ptr, aux_quantized_input_ptr,
            fw_quantized_hidden_state_ptr, scaling_factors_ptr,
            fw_hidden_state_ptr_batch, output_ptr_batch,
            params->asymmetric_quantize_inputs, zero_points_ptr,
            accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums);
      }
      // Backward cell.
      float* bw_hidden_state_ptr_batch =
          GetTensorData<float>(bw_hidden_state) + b * bw_num_units;
      float* bw_output_offset =
          params->merge_outputs
              ? GetTensorData<float>(fw_output) +
                    b * bw_output_step * max_time + fw_num_units
              : GetTensorData<float>(bw_output) + b * bw_output_step * max_time;
      for (int s = max_time - 1; s >= 0; s--) {
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) + b * input_size * max_time +
                      s * input_size
                : nullptr;
        float* output_ptr_batch = bw_output_offset + s * bw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
            aux_input_ptr_batch, aux_bw_input_weights_ptr,
            aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
            bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
            bw_num_units, /*batch_size=*/1, bw_output_step, params->activation,
            quantized_input_ptr, aux_quantized_input_ptr,
            bw_quantized_hidden_state_ptr, scaling_factors_ptr,
            bw_hidden_state_ptr_batch, output_ptr_batch,
            params->asymmetric_quantize_inputs, zero_points_ptr,
            accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums);
      }
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
      node->builtin_data);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* fw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwWeightsTensor,
                                          &fw_input_weights));
  const TfLiteTensor* fw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwRecurrentWeightsTensor,
                                 &fw_recurrent_weights));
  const TfLiteTensor* fw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwBiasTensor, &fw_bias));
  const TfLiteTensor* bw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwWeightsTensor,
                                          &bw_input_weights));
  const TfLiteTensor* bw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwRecurrentWeightsTensor,
                                 &bw_recurrent_weights));
  const TfLiteTensor* bw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwBiasTensor, &bw_bias));

  // Get auxiliary inputs.
  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
  const TfLiteTensor* bw_aux_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);

  TfLiteTensor* fw_hidden_state =
      GetVariableInput(context, node, kFwHiddenStateTensor);
  TFLITE_DCHECK(fw_hidden_state != nullptr);
  TfLiteTensor* bw_hidden_state =
      GetVariableInput(context, node, kBwHiddenStateTensor);
  TFLITE_DCHECK(bw_hidden_state != nullptr);

  TfLiteTensor* fw_output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
  TfLiteTensor* bw_output = params->merge_outputs
                                ? nullptr
                                : GetOutput(context, node, kBwOutputTensor);

  const bool has_previous_bw_output = (aux_input != nullptr);
  const bool use_aux_input = (fw_aux_input_weights != nullptr);

  // We want to cover the following cases:
  //
  // If not stacking (not connected after other bidi lstms):
  //   both fw & bw will just use `input`; aux_input will be null.
  //
  // If stacking with cross_links, TensorFlow equivalent
  // (tf.contrib.rnn.stack_bidirectional_rnn):
  //   both fw & bw will use `input`, but aux_input will be non-null.
  //   Note that this works whether or not it is connected after other bidi
  //   lstms.
  //
  // If stacking without cross_links, but connected after other bidi lstms,
  // TensorFlow equivalent (tf.nn.static_bidirectional_rnn):
  //   fw will use `input`, bw will use aux_input, and the `real aux_input`
  //   will be null.

  const bool non_stacking_mode = !use_aux_input && has_previous_bw_output;
  const TfLiteTensor* bw_input = non_stacking_mode ? aux_input : input;
  const TfLiteTensor* real_aux_input = non_stacking_mode ? nullptr : aux_input;

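  // Dispatch on the weight type: float weights run the float kernel, 8-bit
  // weights run the hybrid kernel with the temporaries allocated in Prepare().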
  switch (fw_input_weights->type) {
    case kTfLiteFloat32:
      return EvalFloat(input, bw_input, fw_input_weights, fw_recurrent_weights,
                       fw_bias, bw_input_weights, bw_recurrent_weights, bw_bias,
                       real_aux_input, fw_aux_input_weights,
                       bw_aux_input_weights, params, fw_hidden_state, fw_output,
                       bw_hidden_state, bw_output);
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      TfLiteTensor* input_quantized;
      TF_LITE_ENSURE_OK(
          context,
          GetTemporarySafe(context, node, kInputQuantized, &input_quantized));
      TfLiteTensor* fw_hidden_state_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kFwHiddenStateQuantized,
                                         &fw_hidden_state_quantized));
      TfLiteTensor* bw_hidden_state_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kBwHiddenStateQuantized,
                                         &bw_hidden_state_quantized));
      TfLiteTensor* scaling_factors;
      TF_LITE_ENSURE_OK(
          context,
          GetTemporarySafe(context, node, kScalingFactors, &scaling_factors));
      TfLiteTensor* zero_points;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kZeroPoints, &zero_points));
      TfLiteTensor* accum_scratch;
      TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
                                                  &accum_scratch));
      TfLiteTensor* fw_row_sums;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
      TfLiteTensor* bw_row_sums;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
      TfLiteTensor* aux_input_quantized =
          use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
                        : nullptr;
      auto* op_data = reinterpret_cast<OpData*>(node->user_data);
      return EvalHybrid(
          input, bw_input, fw_input_weights, fw_recurrent_weights, fw_bias,
          bw_input_weights, bw_recurrent_weights, bw_bias, real_aux_input,
          fw_aux_input_weights, bw_aux_input_weights, params, scaling_factors,
          input_quantized, aux_input_quantized, fw_hidden_state_quantized,
          fw_hidden_state, fw_output, bw_hidden_state_quantized,
          bw_hidden_state, bw_output, zero_points, accum_scratch, fw_row_sums,
          bw_row_sums, &op_data->fw_compute_row_sums,
          &op_data->bw_compute_row_sums);
    }
    default:
      context->ReportError(context, "Type not currently supported.");
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace bidirectional_sequence_rnn

TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN() {
  static TfLiteRegistration r = {
      bidirectional_sequence_rnn::Init, bidirectional_sequence_rnn::Free,
      bidirectional_sequence_rnn::Prepare, bidirectional_sequence_rnn::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite