/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <math.h>

#include <algorithm>
#include <cstddef>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_eval.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace bidirectional_sequence_lstm {

// LINT.IfChange

// Input Tensors of size {max_time, n_batch, n_input}
constexpr int kInputTensor = 0;

// Forward LSTM cell tensors.
// Input weight tensors of size: {n_cell, n_input}
constexpr int kFwInputToInputWeightsTensor = 1;  // Optional
constexpr int kFwInputToForgetWeightsTensor = 2;
constexpr int kFwInputToCellWeightsTensor = 3;
constexpr int kFwInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kFwRecurrentToInputWeightsTensor = 5;  // Optional
constexpr int kFwRecurrentToForgetWeightsTensor = 6;
constexpr int kFwRecurrentToCellWeightsTensor = 7;
constexpr int kFwRecurrentToOutputWeightsTensor = 8;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kFwCellToInputWeightsTensor = 9;    // Optional
constexpr int kFwCellToForgetWeightsTensor = 10;  // Optional
constexpr int kFwCellToOutputWeightsTensor = 11;  // Optional

// Gates bias tensors of size {n_cell}
constexpr int kFwInputGateBiasTensor = 12;  // Optional
constexpr int kFwForgetGateBiasTensor = 13;
constexpr int kFwCellGateBiasTensor = 14;
constexpr int kFwOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kFwProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kFwProjectionBiasTensor = 17;  // Optional

// Backward LSTM cell tensors.
// Input weight tensors of size: {n_cell, n_input}
constexpr int kBwInputToInputWeightsTensor = 18;  // Optional
constexpr int kBwInputToForgetWeightsTensor = 19;
constexpr int kBwInputToCellWeightsTensor = 20;
constexpr int kBwInputToOutputWeightsTensor = 21;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kBwRecurrentToInputWeightsTensor = 22;  // Optional
constexpr int kBwRecurrentToForgetWeightsTensor = 23;
constexpr int kBwRecurrentToCellWeightsTensor = 24;
constexpr int kBwRecurrentToOutputWeightsTensor = 25;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kBwCellToInputWeightsTensor = 26;   // Optional
constexpr int kBwCellToForgetWeightsTensor = 27;  // Optional
constexpr int kBwCellToOutputWeightsTensor = 28;  // Optional

// Gates bias tensors of size {n_cell}
constexpr int kBwInputGateBiasTensor = 29;  // Optional
constexpr int kBwForgetGateBiasTensor = 30;
constexpr int kBwCellGateBiasTensor = 31;
constexpr int kBwOutputGateBiasTensor = 32;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kBwProjectionWeightsTensor = 33;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kBwProjectionBiasTensor = 34;  // Optional

// Stateful input tensors that are variables and will be modified by the Op.
// Activation state tensors of size {n_batch, n_output}
constexpr int kFwInputActivationStateTensor = 35;
// Cell state tensors of size {n_batch, n_cell}
constexpr int kFwInputCellStateTensor = 36;
// Activation state tensors of size {n_batch, n_output}
constexpr int kBwInputActivationStateTensor = 37;
// Cell state tensors of size {n_batch, n_cell}
constexpr int kBwInputCellStateTensor = 38;

// Used as auxiliary input and weights when stacking for the
// tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); used as
// input to the backward cell when stacking for the
// tf.nn.static_bidirectional_rnn case (without cross links).
constexpr int kAuxInputTensor = 39;  // Optional
// Forward weights.
constexpr int kFwAuxInputToInputWeightsTensor = 40;   // Optional
constexpr int kFwAuxInputToForgetWeightsTensor = 41;  // Optional
constexpr int kFwAuxInputToCellWeightsTensor = 42;    // Optional
constexpr int kFwAuxInputToOutputWeightsTensor = 43;  // Optional
// Backward weights.
constexpr int kBwAuxInputToInputWeightsTensor = 44;   // Optional
constexpr int kBwAuxInputToForgetWeightsTensor = 45;  // Optional
constexpr int kBwAuxInputToCellWeightsTensor = 46;    // Optional
constexpr int kBwAuxInputToOutputWeightsTensor = 47;  // Optional

// Output tensors.
constexpr int kFwOutputTensor = 0;
constexpr int kBwOutputTensor = 1;  // Ignored if merge_outputs is set.

// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantize_weights.cc)

// Temporary tensors.
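// Each enum value below is used both as the index into node->temporaries and
// as the offset from OpData::scratch_tensor_index.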
enum TemporaryTensor {
  // Scratch buffers for input, forget, etc. gates
  kFwScratchBuffer = 0,
  kBwScratchBuffer = 1,
  // Quantized tensors needed for the hybrid kernel.
  kInputQuantized = 2,
  kFwActivationStateQuantized = 3,
  kBwActivationStateQuantized = 4,
  kFwCellStateQuantized = 5,
  kBwCellStateQuantized = 6,
  kInputScalingFactors = 7,
  kAuxInputScalingFactors = 8,
  kOutputStateScalingFactors = 9,
  kProductScalingFactors = 10,
  kRecoveredCellWeights = 11,
  kAccumScratchBuffer = 12,
  kInputZeroPoints = 13,
  kAuxInputZeroPoints = 14,
  kOutputStateZeroPoints = 15,
  kFwRowSums = 16,
  kBwRowSums = 17,
  kAuxInputQuantized = 18,  // Optional, quantized tensor for auxiliary input.
  kNumTemporaryTensors = 19,
};

struct OpData {
  int scratch_tensor_index;
  bool compute_fw_row_sums = false;
  bool compute_bw_row_sums = false;
};

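// Allocates the op data and registers kNumTemporaryTensors tensor slots with
// the context, recording the index of the first one.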
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  context->AddTensors(context, kNumTemporaryTensors,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

// Check that input tensor dimensions and types match each other.
TfLiteStatus CheckLstmTensorDimensionsAndTypes(
    TfLiteContext* context, TfLiteNode* node, int n_input, int n_output,
    int n_cell, int input_to_input_weights_tensor,
    int input_to_forget_weights_tensor, int input_to_cell_weights_tensor,
    int input_to_output_weights_tensor, int recurrent_to_input_weights_tensor,
    int recurrent_to_forget_weights_tensor,
    int recurrent_to_cell_weights_tensor,
    int recurrent_to_output_weights_tensor, int cell_to_input_weights_tensor,
    int cell_to_forget_weights_tensor, int cell_to_output_weights_tensor,
    int input_gate_bias_tensor, int forget_gate_bias_tensor,
    int cell_gate_bias_tensor, int output_gate_bias_tensor,
    int projection_weights_tensor, int projection_bias_tensor) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
      node->builtin_data);

  // Making sure clipping parameters have valid values.
  // == 0 means no clipping
  //  > 0 means clipping
  TF_LITE_ENSURE(context, params->cell_clip >= 0);
  TF_LITE_ENSURE(context, params->proj_clip >= 0);

  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, input_to_forget_weights_tensor,
                                 &input_to_forget_weights));
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
  TF_LITE_ENSURE(context, (input_to_forget_weights->type == kTfLiteFloat32) ||
                              (input_to_forget_weights->type == kTfLiteInt8) ||
                              (input_to_forget_weights->type == kTfLiteUInt8));

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, input_to_input_weights_tensor);
  if (input_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
    TF_LITE_ENSURE_TYPES_EQ(context, input_to_input_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, input_to_cell_weights_tensor,
                                 &input_to_cell_weights));
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
  TF_LITE_ENSURE_TYPES_EQ(context, input_to_cell_weights->type,
                          input_to_forget_weights->type);

  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, input_to_output_weights_tensor,
                                 &input_to_output_weights));
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
  TF_LITE_ENSURE_TYPES_EQ(context, input_to_output_weights->type,
                          input_to_forget_weights->type);

  const TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, recurrent_to_input_weights_tensor);
  if (recurrent_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
                      n_output);
    TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_input_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, recurrent_to_forget_weights_tensor,
                            &recurrent_to_forget_weights));
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                    n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
                    n_output);
  TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_forget_weights->type,
                          input_to_forget_weights->type);

  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, recurrent_to_cell_weights_tensor,
                            &recurrent_to_cell_weights));
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
                    n_output);
  TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_cell_weights->type,
                          input_to_forget_weights->type);

  // We make sure the input-gate's parameters are either both present (regular
  // LSTM) or both absent (CIFG-LSTM).
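  // (In a CIFG LSTM the input gate is coupled to the forget gate, so the
  // input-gate weights are omitted entirely.)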
  const bool cifg_weights_all_or_none =
      ((input_to_input_weights != nullptr) &&
       (recurrent_to_input_weights != nullptr)) ||
      ((input_to_input_weights == nullptr) &&
       (recurrent_to_input_weights == nullptr));
  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);

  const TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, cell_to_input_weights_tensor);
  if (cell_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, cell_to_input_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, cell_to_forget_weights_tensor);
  if (cell_to_forget_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, cell_to_forget_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, cell_to_output_weights_tensor);
  if (cell_to_output_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, cell_to_output_weights->type,
                            input_to_forget_weights->type);
  }

  // Making sure the peephole weights are there all or none.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool peephole_weights_all_or_none =
      ((cell_to_input_weights != nullptr || use_cifg) &&
       (cell_to_forget_weights != nullptr) &&
       (cell_to_output_weights != nullptr)) ||
      ((cell_to_input_weights == nullptr) &&
       (cell_to_forget_weights == nullptr) &&
       (cell_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);

  // Make sure the input gate bias is present only when not a CIFG-LSTM.
  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, input_gate_bias_tensor);
  if (use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
  } else {
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32);
  }

  const TfLiteTensor* forget_gate_bias;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, forget_gate_bias_tensor, &forget_gate_bias));
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
  TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);

  const TfLiteTensor* cell_gate_bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, cell_gate_bias_tensor,
                                          &cell_gate_bias));
  TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, cell_gate_bias->type, kTfLiteFloat32);

  const TfLiteTensor* output_gate_bias;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, output_gate_bias_tensor, &output_gate_bias));
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
  TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32);

  const TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, projection_weights_tensor);
  if (projection_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, projection_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, projection_bias_tensor);
  if (projection_bias != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
    TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32);
  }

  // Making sure the projection tensors are consistent:
  // 1) If projection weight is not present, then projection bias should not be
  // present.
  // 2) If projection weight is present, then projection bias is optional.
  // TODO(ghodrat): make sure this is correct.
  const bool projection_tensors_consistent =
      ((projection_weights != nullptr) || (projection_bias == nullptr));
  TF_LITE_ENSURE(context, projection_tensors_consistent == true);

  return kTfLiteOk;
}

TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell) {
  TF_LITE_ENSURE_OK(
      context,
      CheckLstmTensorDimensionsAndTypes(
          context, node, n_input, n_output, n_cell,
          kFwInputToInputWeightsTensor, kFwInputToForgetWeightsTensor,
          kFwInputToCellWeightsTensor, kFwInputToOutputWeightsTensor,
          kFwRecurrentToInputWeightsTensor, kFwRecurrentToForgetWeightsTensor,
          kFwRecurrentToCellWeightsTensor, kFwRecurrentToOutputWeightsTensor,
          kFwCellToInputWeightsTensor, kFwCellToForgetWeightsTensor,
          kFwCellToOutputWeightsTensor, kFwInputGateBiasTensor,
          kFwForgetGateBiasTensor, kFwCellGateBiasTensor,
          kFwOutputGateBiasTensor, kFwProjectionWeightsTensor,
          kFwProjectionBiasTensor));

  TF_LITE_ENSURE_OK(
      context,
      CheckLstmTensorDimensionsAndTypes(
          context, node, n_input, n_output, n_cell,
          kBwInputToInputWeightsTensor, kBwInputToForgetWeightsTensor,
          kBwInputToCellWeightsTensor, kBwInputToOutputWeightsTensor,
          kBwRecurrentToInputWeightsTensor, kBwRecurrentToForgetWeightsTensor,
          kBwRecurrentToCellWeightsTensor, kBwRecurrentToOutputWeightsTensor,
          kBwCellToInputWeightsTensor, kBwCellToForgetWeightsTensor,
          kBwCellToOutputWeightsTensor, kBwInputGateBiasTensor,
          kBwForgetGateBiasTensor, kBwCellGateBiasTensor,
          kBwOutputGateBiasTensor, kBwProjectionWeightsTensor,
          kBwProjectionBiasTensor));

  // Check if Forward and Backward tensors match along required dimensions.
  return kTfLiteOk;
}

// Resize the output and scratch tensors based on the sizes of the input
// tensors. Also check that the sizes of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
      node->builtin_data);

  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 48);
  TF_LITE_ENSURE_EQ(context, node->outputs->size,
                    params->merge_outputs ? 1 : 2);

  // Inferring batch size, number of outputs, sequence length and number of
  // cells from the input tensors.
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
  const bool time_major = params->time_major;
  const int max_time = time_major ? input->dims->data[0] : input->dims->data[1];
  const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
  const int n_input = input->dims->data[2];

  const TfLiteTensor* fw_input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwInputToOutputWeightsTensor,
                                 &fw_input_to_output_weights));
  const int n_fw_cell = fw_input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->data[1],
                    n_input);

  const TfLiteTensor* bw_input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwInputToOutputWeightsTensor,
                                 &bw_input_to_output_weights));
  const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1],
                    n_input);
  TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->type,
                    fw_input_to_output_weights->type);

  const TfLiteTensor* fw_recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kFwRecurrentToOutputWeightsTensor,
                            &fw_recurrent_to_output_weights));
  TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->data[0],
                    n_fw_cell);
  TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->type,
                    fw_input_to_output_weights->type);
  const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];

  const TfLiteTensor* bw_recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kBwRecurrentToOutputWeightsTensor,
                            &bw_recurrent_to_output_weights));
  TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0],
                    n_bw_cell);
  TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->type,
                    fw_input_to_output_weights->type);
  const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];

  // Check that input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(
      context, CheckInputTensorDimensions(context, node, n_input, n_fw_output,
                                          n_fw_cell));

  // Get (optional) auxiliary inputs and weights.
  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);

  const bool aux_inputs_weights_all_or_none =
      ((fw_aux_input_to_cell_weights != nullptr) &&
       (fw_aux_input_to_forget_weights != nullptr) &&
       (fw_aux_input_to_output_weights != nullptr) &&
       (bw_aux_input_to_cell_weights != nullptr) &&
       (bw_aux_input_to_forget_weights != nullptr) &&
       (bw_aux_input_to_output_weights != nullptr)) ||
      ((fw_aux_input_to_cell_weights == nullptr) &&
       (fw_aux_input_to_forget_weights == nullptr) &&
       (fw_aux_input_to_output_weights == nullptr) &&
       (bw_aux_input_to_cell_weights == nullptr) &&
       (bw_aux_input_to_forget_weights == nullptr) &&
       (bw_aux_input_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, aux_inputs_weights_all_or_none);

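  // The forget-gate aux weights stand in for the whole set here, since the
  // all-or-none property was just verified above.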
  const bool has_aux_input = (fw_aux_input_to_forget_weights != nullptr);

  if (has_aux_input) {
    // Check that aux_input has the same dimensions (except last) as the input.
    TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
  }

  // Get the pointer to output, activation_state and cell_state buffer tensors.
  TfLiteTensor* fw_output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
  TfLiteTensor* fw_activation_state =
      GetVariableInput(context, node, kFwInputActivationStateTensor);
  TF_LITE_ENSURE(context, fw_activation_state != nullptr);
  TfLiteTensor* fw_cell_state =
      GetVariableInput(context, node, kFwInputCellStateTensor);
  TF_LITE_ENSURE(context, fw_cell_state != nullptr);

  // Check the shape of input state tensors.
  // These tensors may be 1D or 2D. That is fine as long as the total size is
  // correct.
  TF_LITE_ENSURE_EQ(context, NumElements(fw_activation_state),
                    n_batch * n_fw_output);
  TF_LITE_ENSURE_EQ(context, NumElements(fw_cell_state), n_batch * n_fw_cell);

  // Resize the output tensors.
  TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3);
  fw_output_size->data[0] = time_major ? max_time : n_batch;
  fw_output_size->data[1] = time_major ? n_batch : max_time;
  fw_output_size->data[2] =
      params->merge_outputs ? n_bw_output + n_fw_output : n_fw_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, fw_output, fw_output_size));

  // The weights are of consistent type, so it suffices to check one.
  const bool is_hybrid_op = IsHybridOp(input, fw_input_to_output_weights);

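  // A hybrid op keeps float activations but stores the weights in 8 bits, so
  // it needs extra temporaries for quantized copies of the inputs and states,
  // plus per-batch scaling factors and zero points.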
  TfLiteIntArrayFree(node->temporaries);
  if (is_hybrid_op) {
    node->temporaries = TfLiteIntArrayCreate(
        has_aux_input ? kNumTemporaryTensors : kNumTemporaryTensors - 1);
  } else {
    node->temporaries = TfLiteIntArrayCreate(2);  // the two scratch buffers.
  }
  // Create a scratch buffer tensor.
  node->temporaries->data[kFwScratchBuffer] =
      op_data->scratch_tensor_index + kFwScratchBuffer;
  TfLiteTensor* fw_scratch_buffer;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kFwScratchBuffer,
                                              &fw_scratch_buffer));
  fw_scratch_buffer->type = input->type;
  fw_scratch_buffer->allocation_type = kTfLiteArenaRw;

  const TfLiteTensor* fw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
  const bool fw_use_cifg = (fw_input_to_input_weights == nullptr);
  if (has_aux_input && !fw_use_cifg) {
    TF_LITE_ENSURE_EQ(context, fw_aux_input_to_input_weights->dims->data[0],
                      fw_input_to_input_weights->dims->data[0]);
  }
  TfLiteIntArray* fw_scratch_buffer_size = TfLiteIntArrayCreate(2);
  fw_scratch_buffer_size->data[0] = n_batch;
  if (fw_use_cifg) {
    // Reserving space for Cell, Forget, Output gates
    fw_scratch_buffer_size->data[1] = n_fw_cell * 3;
  } else {
    // Reserving space for Input, Cell, Forget, Output gates
    fw_scratch_buffer_size->data[1] = n_fw_cell * 4;
  }
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_scratch_buffer,
                                                   fw_scratch_buffer_size));
  // Same for the backward cell.

  // Check that input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(
      context, CheckInputTensorDimensions(context, node, n_input, n_bw_output,
                                          n_bw_cell));

  // Get the pointer to activation_state and cell_state buffer tensors.
  TfLiteTensor* bw_activation_state =
      GetVariableInput(context, node, kBwInputActivationStateTensor);
  TF_LITE_ENSURE(context, bw_activation_state != nullptr);
  TfLiteTensor* bw_cell_state =
      GetVariableInput(context, node, kBwInputCellStateTensor);
  TF_LITE_ENSURE(context, bw_cell_state != nullptr);

  // Resize the output tensors.
  if (!params->merge_outputs) {
    TfLiteTensor* bw_output;
    TF_LITE_ENSURE_OK(
        context, GetOutputSafe(context, node, kBwOutputTensor, &bw_output));
    TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
    bw_output_size->data[0] = time_major ? max_time : n_batch;
    bw_output_size->data[1] = time_major ? n_batch : max_time;
    bw_output_size->data[2] = n_bw_output;
    TF_LITE_ENSURE_OK(
        context, context->ResizeTensor(context, bw_output, bw_output_size));
  }

  // Check the shape of input state tensors.
  // These tensors may be 1D or 2D. That is fine as long as the total size is
  // correct.
  TF_LITE_ENSURE_EQ(context, NumElements(bw_activation_state),
                    n_batch * n_bw_output);
  TF_LITE_ENSURE_EQ(context, NumElements(bw_cell_state), n_batch * n_bw_cell);

  // Create a scratch buffer tensor.
  node->temporaries->data[kBwScratchBuffer] =
      op_data->scratch_tensor_index + kBwScratchBuffer;
  TfLiteTensor* bw_scratch_buffer;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kBwScratchBuffer,
                                              &bw_scratch_buffer));
  bw_scratch_buffer->type = input->type;
  bw_scratch_buffer->allocation_type = kTfLiteArenaRw;

  const TfLiteTensor* bw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
  const bool bw_use_cifg = (bw_input_to_input_weights == nullptr);
  if (has_aux_input && !bw_use_cifg) {
    TF_LITE_ENSURE_EQ(context, bw_aux_input_to_input_weights->dims->data[0],
                      bw_input_to_input_weights->dims->data[0]);
  }
  TfLiteIntArray* bw_scratch_buffer_size = TfLiteIntArrayCreate(2);
  bw_scratch_buffer_size->data[0] = n_batch;
  if (bw_use_cifg) {
    // Reserving space for Cell, Forget, Output gates
    bw_scratch_buffer_size->data[1] = n_bw_cell * 3;
  } else {
    // Reserving space for Input, Cell, Forget, Output gates
    bw_scratch_buffer_size->data[1] = n_bw_cell * 4;
  }
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer,
                                                   bw_scratch_buffer_size));
  if (is_hybrid_op) {
    // Compute the row sums for cached zero_point offset calculation.
    op_data->compute_fw_row_sums = true;
    op_data->compute_bw_row_sums = true;
    // Allocate temporary tensors to store quantized values of input, aux_input
    // (if present), activation_state and cell_state tensors.
    node->temporaries->data[kInputQuantized] =
        op_data->scratch_tensor_index + kInputQuantized;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
                                                &input_quantized));
    input_quantized->type = fw_input_to_output_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }

    node->temporaries->data[kFwActivationStateQuantized] =
        op_data->scratch_tensor_index + kFwActivationStateQuantized;
    TfLiteTensor* fw_activation_state_quantized;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kFwActivationStateQuantized,
                                  &fw_activation_state_quantized));
    fw_activation_state_quantized->type = fw_input_to_output_weights->type;
    fw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_activation_state_quantized->dims,
                             fw_activation_state->dims)) {
      TfLiteIntArray* fw_activation_state_quantized_size =
          TfLiteIntArrayCopy(fw_activation_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, fw_activation_state_quantized,
                                         fw_activation_state_quantized_size));
    }
    node->temporaries->data[kBwActivationStateQuantized] =
        op_data->scratch_tensor_index + kBwActivationStateQuantized;
    TfLiteTensor* bw_activation_state_quantized;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kBwActivationStateQuantized,
                                  &bw_activation_state_quantized));
    bw_activation_state_quantized->type = fw_input_to_output_weights->type;
    bw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(bw_activation_state_quantized->dims,
                             bw_activation_state->dims)) {
      TfLiteIntArray* bw_activation_state_quantized_size =
          TfLiteIntArrayCopy(bw_activation_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, bw_activation_state_quantized,
                                         bw_activation_state_quantized_size));
    }
    node->temporaries->data[kFwCellStateQuantized] =
        op_data->scratch_tensor_index + kFwCellStateQuantized;
    TfLiteTensor* fw_cell_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kFwCellStateQuantized,
                                       &fw_cell_state_quantized));
    fw_cell_state_quantized->type = fw_input_to_output_weights->type;
    fw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_cell_state_quantized->dims,
                             fw_cell_state->dims)) {
      TfLiteIntArray* fw_cell_state_quantized_size =
          TfLiteIntArrayCopy(fw_cell_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, fw_cell_state_quantized,
                                              fw_cell_state_quantized_size));
    }
    node->temporaries->data[kBwCellStateQuantized] =
        op_data->scratch_tensor_index + kBwCellStateQuantized;
    TfLiteTensor* bw_cell_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kBwCellStateQuantized,
                                       &bw_cell_state_quantized));
    bw_cell_state_quantized->type = fw_input_to_output_weights->type;
    bw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(bw_cell_state_quantized->dims,
                             bw_cell_state->dims)) {
      TfLiteIntArray* bw_cell_state_quantized_size =
          TfLiteIntArrayCopy(bw_cell_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, bw_cell_state_quantized,
                                              bw_cell_state_quantized_size));
    }

    // Allocate temporary tensors to store scaling factors and product scaling
    // factors. The latter is convenient storage that allows a vector to be
    // quantized once (producing the scaling factors) and multiplied with
    // different matrices (which requires multiplying its scaling factors by
    // the scaling factor of each matrix).
    node->temporaries->data[kInputScalingFactors] =
        op_data->scratch_tensor_index + kInputScalingFactors;
    TfLiteTensor* input_sf;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, kInputScalingFactors, &input_sf));
    input_sf->type = kTfLiteFloat32;
    input_sf->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {n_batch};
    if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) {
      TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1);
      input_sf_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, input_sf, input_sf_size));
    }
    node->temporaries->data[kAuxInputScalingFactors] =
        op_data->scratch_tensor_index + kAuxInputScalingFactors;
    TfLiteTensor* aux_input_sf;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kAuxInputScalingFactors,
                                       &aux_input_sf));
    aux_input_sf->type = kTfLiteFloat32;
    aux_input_sf->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(aux_input_sf->dims, 1, scaling_dims)) {
      TfLiteIntArray* aux_input_sf_size = TfLiteIntArrayCreate(1);
      aux_input_sf_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, aux_input_sf,
                                                       aux_input_sf_size));
    }
    node->temporaries->data[kOutputStateScalingFactors] =
        op_data->scratch_tensor_index + kOutputStateScalingFactors;
    TfLiteTensor* output_state_sf;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kOutputStateScalingFactors,
                                  &output_state_sf));
    output_state_sf->type = kTfLiteFloat32;
    output_state_sf->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
      TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1);
      output_state_sf_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf,
                                                       output_state_sf_size));
    }
    node->temporaries->data[kProductScalingFactors] =
        op_data->scratch_tensor_index + kProductScalingFactors;
    TfLiteTensor* prod_scaling_factors;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kProductScalingFactors,
                                       &prod_scaling_factors));
    prod_scaling_factors->type = kTfLiteFloat32;
    prod_scaling_factors->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
                                   scaling_dims)) {
      TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
      prod_scaling_factors_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, prod_scaling_factors,
                                              prod_scaling_factors_size));
    }

    // Allocate a temporary tensor to store the recovered cell weights. Since
    // this is used for diagonal matrices, we only need to store n_cell values.
    node->temporaries->data[kRecoveredCellWeights] =
        op_data->scratch_tensor_index + kRecoveredCellWeights;
    TfLiteTensor* recovered_cell_weights;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kRecoveredCellWeights,
                                       &recovered_cell_weights));
    recovered_cell_weights->type = kTfLiteFloat32;
    recovered_cell_weights->allocation_type = kTfLiteArenaRw;
    int recovered_cell_dims[1] = {n_fw_cell};
    if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1,
                                   recovered_cell_dims)) {
      TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
      recovered_cell_weights_size->data[0] = n_fw_cell;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, recovered_cell_weights,
                                              recovered_cell_weights_size));
    }

    // Allocate a temporary tensor to store the accumulated int32 values.
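    // The accumulator is shared by all the matrix multiplications, so it is
    // sized for the largest output dimension used by either direction.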
    node->temporaries->data[kAccumScratchBuffer] =
        op_data->scratch_tensor_index + kAccumScratchBuffer;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, kAccumScratchBuffer, &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int n_cell = std::max(n_fw_cell, n_bw_cell);
    if (has_aux_input) {
      n_cell = std::max(n_cell, fw_aux_input_to_output_weights->dims->data[0]);
      n_cell = std::max(n_cell, bw_aux_input_to_output_weights->dims->data[0]);
    }
    int accum_scratch_dims[2] = {n_cell, n_batch};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
      accum_size->data[0] = n_cell;
      accum_size->data[1] = n_batch;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, accum_scratch, accum_size));
    }

    // Allocate temporary tensors for storing zero-points.
    node->temporaries->data[kInputZeroPoints] =
        op_data->scratch_tensor_index + kInputZeroPoints;
    TfLiteTensor* input_zp;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp));
    input_zp->type = kTfLiteFloat32;
    input_zp->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
      TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1);
      input_zp_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, input_zp, input_zp_size));
    }
    node->temporaries->data[kAuxInputZeroPoints] =
        op_data->scratch_tensor_index + kAuxInputZeroPoints;
    TfLiteTensor* aux_input_zp;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, kAuxInputZeroPoints, &aux_input_zp));
    aux_input_zp->type = kTfLiteFloat32;
    aux_input_zp->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(aux_input_zp->dims, 1, scaling_dims)) {
      TfLiteIntArray* aux_input_zp_size = TfLiteIntArrayCreate(1);
      aux_input_zp_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, aux_input_zp,
                                                       aux_input_zp_size));
    }
    node->temporaries->data[kOutputStateZeroPoints] =
        op_data->scratch_tensor_index + kOutputStateZeroPoints;
    TfLiteTensor* output_state_zp;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kOutputStateZeroPoints,
                                       &output_state_zp));
    output_state_zp->type = kTfLiteFloat32;
    output_state_zp->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
      TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1);
      output_state_zp_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp,
                                                       output_state_zp_size));
    }

    // Allocate temporary tensors for caching row sums for hybrid zero-point
    // calculations.
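    // A non-CIFG cell has four input-to-gate and four recurrent-to-gate weight
    // matrices (one row of sums each); CIFG drops the input gate from both
    // sets. Auxiliary-input and projection weights add further rows below.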
    int fw_row_sums_rows = fw_use_cifg ? 6 : 8;
    if (has_aux_input) {
      fw_row_sums_rows += fw_use_cifg ? 3 : 4;
    }
    const TfLiteTensor* fw_projection_weights =
        GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor);
    if (fw_projection_weights != nullptr) {
      fw_row_sums_rows += ceil(static_cast<float>(n_fw_output) / n_fw_cell);
    }
    node->temporaries->data[kFwRowSums] =
        op_data->scratch_tensor_index + kFwRowSums;
    TfLiteTensor* fw_row_sums;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
    fw_row_sums->type = kTfLiteInt32;
    fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int fw_row_sums_dims[2] = {fw_row_sums_rows, n_fw_cell};
    if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) {
      TfLiteIntArray* fw_hybrid_scratch_size = TfLiteIntArrayCreate(2);
      fw_hybrid_scratch_size->data[0] = fw_row_sums_dims[0];
      fw_hybrid_scratch_size->data[1] = fw_row_sums_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums,
                                                       fw_hybrid_scratch_size));
    }

    int bw_row_sums_rows = bw_use_cifg ? 6 : 8;
    if (has_aux_input) {
      bw_row_sums_rows += bw_use_cifg ? 3 : 4;
    }
    const TfLiteTensor* bw_projection_weights =
        GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor);
    if (bw_projection_weights != nullptr) {
      bw_row_sums_rows += ceil(static_cast<float>(n_bw_output) / n_bw_cell);
    }
    node->temporaries->data[kBwRowSums] =
        op_data->scratch_tensor_index + kBwRowSums;
    TfLiteTensor* bw_row_sums;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
    bw_row_sums->type = kTfLiteInt32;
    bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int bw_row_sums_dims[2] = {bw_row_sums_rows, n_bw_cell};
    if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) {
      TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2);
      bw_row_sums_size->data[0] = bw_row_sums_dims[0];
      bw_row_sums_size->data[1] = bw_row_sums_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums,
                                                       bw_row_sums_size));
    }

    // Only allocate a temporary tensor for quantized auxiliary input if we are
    // actually going to use it.
    if (has_aux_input) {
      node->temporaries->data[kAuxInputQuantized] =
          op_data->scratch_tensor_index + kAuxInputQuantized;
      TfLiteTensor* aux_input_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kAuxInputQuantized,
                                         &aux_input_quantized));
      aux_input_quantized->type = fw_input_to_output_weights->type;
      aux_input_quantized->allocation_type = kTfLiteArenaRw;
      if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
        TfLiteIntArray* aux_input_quantized_size =
            TfLiteIntArrayCopy(aux_input->dims);
        TF_LITE_ENSURE_OK(context,
                          context->ResizeTensor(context, aux_input_quantized,
                                                aux_input_quantized_size));
      }
    }
  }
  return kTfLiteOk;
}

// The LSTM Op engine.
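// Gathers the weight, bias, state, and temporary tensors for both directions
// and then runs the sequence through the lstm_eval routines (float or hybrid,
// depending on the weight type).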
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
970   const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
971       node->builtin_data);
972   auto* op_data = reinterpret_cast<OpData*>(node->user_data);
973   // Input tensor.
974   const TfLiteTensor* input;
975   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
976 
977   // Tensors for the forward cell.
978   const TfLiteTensor* fw_input_to_input_weights =
979       GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
980   const TfLiteTensor* fw_input_to_forget_weights;
981   TF_LITE_ENSURE_OK(context,
982                     GetInputSafe(context, node, kFwInputToForgetWeightsTensor,
983                                  &fw_input_to_forget_weights));
984   const TfLiteTensor* fw_input_to_cell_weights;
985   TF_LITE_ENSURE_OK(context,
986                     GetInputSafe(context, node, kFwInputToCellWeightsTensor,
987                                  &fw_input_to_cell_weights));
988   const TfLiteTensor* fw_input_to_output_weights;
989   TF_LITE_ENSURE_OK(context,
990                     GetInputSafe(context, node, kFwInputToOutputWeightsTensor,
991                                  &fw_input_to_output_weights));
992 
993   const TfLiteTensor* fw_recurrent_to_input_weights =
994       GetOptionalInputTensor(context, node, kFwRecurrentToInputWeightsTensor);
995   const TfLiteTensor* fw_recurrent_to_forget_weights;
996   TF_LITE_ENSURE_OK(
997       context, GetInputSafe(context, node, kFwRecurrentToForgetWeightsTensor,
998                             &fw_recurrent_to_forget_weights));
999   const TfLiteTensor* fw_recurrent_to_cell_weights;
1000   TF_LITE_ENSURE_OK(context,
1001                     GetInputSafe(context, node, kFwRecurrentToCellWeightsTensor,
1002                                  &fw_recurrent_to_cell_weights));
1003   const TfLiteTensor* fw_recurrent_to_output_weights;
1004   TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kFwRecurrentToOutputWeightsTensor,
                            &fw_recurrent_to_output_weights));

  const TfLiteTensor* fw_cell_to_input_weights =
      GetOptionalInputTensor(context, node, kFwCellToInputWeightsTensor);
  const TfLiteTensor* fw_cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kFwCellToForgetWeightsTensor);
  const TfLiteTensor* fw_cell_to_output_weights =
      GetOptionalInputTensor(context, node, kFwCellToOutputWeightsTensor);

  const TfLiteTensor* fw_input_gate_bias =
      GetOptionalInputTensor(context, node, kFwInputGateBiasTensor);
  const TfLiteTensor* fw_forget_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwForgetGateBiasTensor,
                                 &fw_forget_gate_bias));
  const TfLiteTensor* fw_cell_gate_bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwCellGateBiasTensor,
                                          &fw_cell_gate_bias));
  const TfLiteTensor* fw_output_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwOutputGateBiasTensor,
                                 &fw_output_gate_bias));

  const TfLiteTensor* fw_projection_weights =
      GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor);
  const TfLiteTensor* fw_projection_bias =
      GetOptionalInputTensor(context, node, kFwProjectionBiasTensor);

  TfLiteTensor* fw_activation_state =
      GetVariableInput(context, node, kFwInputActivationStateTensor);
  TFLITE_DCHECK(fw_activation_state != nullptr);
  TfLiteTensor* fw_cell_state =
      GetVariableInput(context, node, kFwInputCellStateTensor);
  TFLITE_DCHECK(fw_cell_state != nullptr);
  TfLiteTensor* fw_output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));

  // Tensors for the backward cell.
  const TfLiteTensor* bw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
  const TfLiteTensor* bw_input_to_forget_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwInputToForgetWeightsTensor,
                                 &bw_input_to_forget_weights));
  const TfLiteTensor* bw_input_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwInputToCellWeightsTensor,
                                 &bw_input_to_cell_weights));
  const TfLiteTensor* bw_input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwInputToOutputWeightsTensor,
                                 &bw_input_to_output_weights));

  const TfLiteTensor* bw_recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kBwRecurrentToInputWeightsTensor);
  const TfLiteTensor* bw_recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kBwRecurrentToForgetWeightsTensor,
                            &bw_recurrent_to_forget_weights));
  const TfLiteTensor* bw_recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwRecurrentToCellWeightsTensor,
                                 &bw_recurrent_to_cell_weights));
  const TfLiteTensor* bw_recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kBwRecurrentToOutputWeightsTensor,
                            &bw_recurrent_to_output_weights));

  const TfLiteTensor* bw_cell_to_input_weights =
      GetOptionalInputTensor(context, node, kBwCellToInputWeightsTensor);
  const TfLiteTensor* bw_cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kBwCellToForgetWeightsTensor);
  const TfLiteTensor* bw_cell_to_output_weights =
      GetOptionalInputTensor(context, node, kBwCellToOutputWeightsTensor);

  const TfLiteTensor* bw_input_gate_bias =
      GetOptionalInputTensor(context, node, kBwInputGateBiasTensor);
  const TfLiteTensor* bw_forget_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwForgetGateBiasTensor,
                                 &bw_forget_gate_bias));
  const TfLiteTensor* bw_cell_gate_bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwCellGateBiasTensor,
                                          &bw_cell_gate_bias));
  const TfLiteTensor* bw_output_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwOutputGateBiasTensor,
                                 &bw_output_gate_bias));

  const TfLiteTensor* bw_projection_weights =
      GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor);
  const TfLiteTensor* bw_projection_bias =
      GetOptionalInputTensor(context, node, kBwProjectionBiasTensor);

  // State tensors.
  TfLiteTensor* bw_activation_state =
      GetVariableInput(context, node, kBwInputActivationStateTensor);
  TFLITE_DCHECK(bw_activation_state != nullptr);
  TfLiteTensor* bw_cell_state =
      GetVariableInput(context, node, kBwInputCellStateTensor);
  TFLITE_DCHECK(bw_cell_state != nullptr);
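  // When outputs are merged there is no separate backward output tensor; the
  // backward pass writes into fw_output at an offset (see bw_output_offset
  // below).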
  TfLiteTensor* bw_output = params->merge_outputs
                                ? nullptr
                                : GetOutput(context, node, kBwOutputTensor);

  // Temporary tensors.
  TfLiteTensor* fw_scratch_buffer;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kFwScratchBuffer,
                                              &fw_scratch_buffer));
  TfLiteTensor* bw_scratch_buffer;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kBwScratchBuffer,
                                              &bw_scratch_buffer));

  // (Optional) auxiliary inputs.
  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);

  const bool has_previous_bw_output = (aux_input != nullptr);
  const bool use_aux_input = (fw_aux_input_to_forget_weights != nullptr);

  // Populate a TfLiteLSTMParams struct for the evaluation functions.
  TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip,
                                  params->proj_clip, kTfLiteLSTMFullKernel,
                                  params->asymmetric_quantize_inputs};

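  // With merged outputs, the fw and bw results are concatenated along the
  // last (output) dimension of fw_output; the backward pass starts writing at
  // channel offset fw_n_output, i.e. dims->data[1] of the forward
  // recurrent-to-output weights (shape {n_cell, n_output}).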
  const int bw_output_offset =
      params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0;
  const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output;

  const bool time_major = params->time_major;

  // We want to cover the following cases:
  //
  // If not stacking (not connected after other bidi lstms):
  //   both fw & bw will just use `input`; aux_input will be null.
  //
  // If stacking with cross_links, TensorFlow equivalent
  // (tf.contrib.rnn.stack_bidirectional_rnn):
  //   both fw & bw will use `input`, but aux_input will be non-null.
  //   Note that this case works whether or not it is connected after other
  //   bidi lstms.
  //
  // If stacking without cross_links, but connected after other bidi lstms,
  // TensorFlow equivalent (tf.nn.static_bidirectional_rnn):
  //   fw will use `input`, bw will use aux_input, and the `real aux_input`
  //   will be null.
  // The resulting routing is summarized in the table below.

  const bool non_stacking_mode = !use_aux_input && has_previous_bw_output;
  const TfLiteTensor* bw_input = non_stacking_mode ? aux_input : input;
  const TfLiteTensor* real_aux_input = non_stacking_mode ? nullptr : aux_input;

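  // Routing summary for the assignments above (a sketch; names are the local
  // variables):
  //
  //   use_aux_input | has_previous_bw_output | bw_input  | real_aux_input
  //   --------------+------------------------+-----------+---------------
  //   false         | false                  | input     | nullptr
  //   false         | true                   | aux_input | nullptr
  //   true          | true or false          | input     | aux_input
  //
  // Only the middle row (stacking without cross-links) rewires the backward
  // input; it is the case where `non_stacking_mode` is true.

  // Dispatch on the type of the forward input-to-output weights: float32
  // weights take the float path, while uint8/int8 weights take the hybrid
  // path (quantized weights, float activations).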
  switch (fw_input_to_output_weights->type) {
    case kTfLiteFloat32: {
      TfLiteStatus fw_pass_status = lstm_eval::EvalFloat(
          input, fw_input_to_input_weights, fw_input_to_forget_weights,
          fw_input_to_cell_weights, fw_input_to_output_weights,
          fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
          fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
          fw_cell_to_input_weights, fw_cell_to_forget_weights,
          fw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
          fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
          fw_input_gate_bias, fw_forget_gate_bias, fw_cell_gate_bias,
          fw_output_gate_bias, fw_projection_weights, fw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/true, time_major, /*output_offset=*/0,
          fw_scratch_buffer, fw_activation_state, fw_cell_state, fw_output);
      TF_LITE_ENSURE_OK(context, fw_pass_status);

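      // Backward pass: same parameter layout, but the sequence is traversed
      // in reverse (forward_sequence=false) and, when outputs are merged,
      // results land at bw_output_offset inside fw_output.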
      TfLiteStatus bw_pass_status = lstm_eval::EvalFloat(
          bw_input, bw_input_to_input_weights, bw_input_to_forget_weights,
          bw_input_to_cell_weights, bw_input_to_output_weights,
          bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
          bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
          bw_cell_to_input_weights, bw_cell_to_forget_weights,
          bw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
          bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
          bw_input_gate_bias, bw_forget_gate_bias, bw_cell_gate_bias,
          bw_output_gate_bias, bw_projection_weights, bw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/false, time_major, bw_output_offset,
          bw_scratch_buffer, bw_activation_state, bw_cell_state,
          actual_bw_output);
      TF_LITE_ENSURE_OK(context, bw_pass_status);
      return kTfLiteOk;
    }
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      TfLiteTensor* input_quantized;
      TF_LITE_ENSURE_OK(
          context,
          GetTemporarySafe(context, node, kInputQuantized, &input_quantized));
      TfLiteTensor* fw_activation_state_quantized;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kFwActivationStateQuantized,
                                    &fw_activation_state_quantized));
      TfLiteTensor* bw_activation_state_quantized;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kBwActivationStateQuantized,
                                    &bw_activation_state_quantized));
      TfLiteTensor* fw_cell_state_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kFwCellStateQuantized,
                                         &fw_cell_state_quantized));
      TfLiteTensor* bw_cell_state_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kBwCellStateQuantized,
                                         &bw_cell_state_quantized));
      TfLiteTensor* prod_scaling_factors;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kProductScalingFactors,
                                         &prod_scaling_factors));
      TfLiteTensor* recovered_cell_weights;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kRecoveredCellWeights,
                                         &recovered_cell_weights));
      TfLiteTensor* aux_input_quantized =
          use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
                        : nullptr;
      TfLiteTensor* accum_scratch;
      TF_LITE_ENSURE_OK(
          context,
          GetTemporarySafe(context, node, kAccumScratchBuffer, &accum_scratch));
      TfLiteTensor* fw_row_sums;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
      TfLiteTensor* bw_row_sums;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
      const int fw_row_sums_size = fw_row_sums->dims->data[0];
      const int bw_row_sums_size = bw_row_sums->dims->data[0];
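      // Hybrid path: weights stay quantized while activations remain float.
      // The *_quantized temporaries hold on-the-fly int8 copies of inputs and
      // states, the scaling-factor/zero-point temporaries hold their
      // per-batch quantization parameters, and the row-sum caches (guarded by
      // compute_fw_row_sums/compute_bw_row_sums) avoid recomputing the weight
      // row sums used for asymmetric quantization.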
      TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
          input, fw_input_to_input_weights,
          /*input_to_input_weights_ledger*/ nullptr, fw_input_to_forget_weights,
          /*input_to_forget_weights_ledger*/ nullptr, fw_input_to_cell_weights,
          /*input_to_cell_weights_ledger*/ nullptr, fw_input_to_output_weights,
          /*input_to_output_weights_ledger*/ nullptr,
          fw_recurrent_to_input_weights,
          /*recurrent_to_input_weights_ledger*/ nullptr,
          fw_recurrent_to_forget_weights,
          /*recurrent_to_forget_weights_ledger*/ nullptr,
          fw_recurrent_to_cell_weights,
          /*recurrent_to_cell_weights_ledger*/ nullptr,
          fw_recurrent_to_output_weights,
          /*recurrent_to_output_weights_ledger*/ nullptr,
          fw_cell_to_input_weights, fw_cell_to_forget_weights,
          fw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
          fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
          fw_input_gate_bias, fw_forget_gate_bias, fw_cell_gate_bias,
          fw_output_gate_bias, fw_projection_weights,
          /*projection_weights_ledger*/ nullptr, fw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
          fw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),
          GetTemporary(context, node, kAuxInputScalingFactors),
          GetTemporary(context, node, kOutputStateScalingFactors),
          prod_scaling_factors, recovered_cell_weights, input_quantized,
          aux_input_quantized, fw_activation_state_quantized,
          fw_cell_state_quantized, fw_activation_state, fw_cell_state,
          accum_scratch, fw_output,
          GetTemporary(context, node, kInputZeroPoints),
          GetTemporary(context, node, kAuxInputZeroPoints),
          GetTemporary(context, node, kOutputStateZeroPoints), fw_row_sums,
          fw_row_sums_size, &op_data->compute_fw_row_sums,
          CpuBackendContext::GetFromContext(context));
      TF_LITE_ENSURE_OK(context, fw_pass_status);

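      // Backward hybrid pass: mirrors the forward call, reusing the shared
      // input/aux scaling-factor, zero-point, and accumulator temporaries.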
      TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
          bw_input, bw_input_to_input_weights,
          /*input_to_input_weights_ledger*/ nullptr, bw_input_to_forget_weights,
          /*input_to_forget_weights_ledger*/ nullptr, bw_input_to_cell_weights,
          /*input_to_cell_weights_ledger*/ nullptr, bw_input_to_output_weights,
          /*input_to_output_weights_ledger*/ nullptr,
          bw_recurrent_to_input_weights,
          /*recurrent_to_input_weights_ledger*/ nullptr,
          bw_recurrent_to_forget_weights,
          /*recurrent_to_forget_weights_ledger*/ nullptr,
          bw_recurrent_to_cell_weights,
          /*recurrent_to_cell_weights_ledger*/ nullptr,
          bw_recurrent_to_output_weights,
          /*recurrent_to_output_weights_ledger*/ nullptr,
          bw_cell_to_input_weights, bw_cell_to_forget_weights,
          bw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
          bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
          bw_input_gate_bias, bw_forget_gate_bias, bw_cell_gate_bias,
          bw_output_gate_bias, bw_projection_weights,
          /*projection_weights_ledger*/ nullptr, bw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
          bw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),
          GetTemporary(context, node, kAuxInputScalingFactors),
          GetTemporary(context, node, kOutputStateScalingFactors),
          prod_scaling_factors, recovered_cell_weights, input_quantized,
          aux_input_quantized, bw_activation_state_quantized,
          bw_cell_state_quantized, bw_activation_state, bw_cell_state,
          accum_scratch, actual_bw_output,
          GetTemporary(context, node, kInputZeroPoints),
          GetTemporary(context, node, kAuxInputZeroPoints),
          GetTemporary(context, node, kOutputStateZeroPoints), bw_row_sums,
          bw_row_sums_size, &op_data->compute_bw_row_sums,
          CpuBackendContext::GetFromContext(context));
      TF_LITE_ENSURE_OK(context, bw_pass_status);
      return kTfLiteOk;
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s is not currently supported.",
                         TfLiteTypeGetName(fw_input_to_output_weights->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace bidirectional_sequence_lstm

TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM() {
  static TfLiteRegistration r = {
      bidirectional_sequence_lstm::Init, bidirectional_sequence_lstm::Free,
      bidirectional_sequence_lstm::Prepare, bidirectional_sequence_lstm::Eval};
  return &r;
}
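
// Usage sketch (not part of this file): the interpreter picks this kernel up
// through the builtin op resolver, roughly:
//   resolver.AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
//                       Register_BIDIRECTIONAL_SEQUENCE_LSTM());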

}  // namespace builtin
}  // namespace ops
}  // namespace tflite