/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <math.h>

#include <algorithm>
#include <cstddef>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_eval.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace bidirectional_sequence_lstm {

// LINT.IfChange

// Input Tensors of size {max_time, n_batch, n_input}
constexpr int kInputTensor = 0;

// Forward LSTM cell tensors.
// Input weight tensors of size: {n_cell, n_input}
constexpr int kFwInputToInputWeightsTensor = 1;  // Optional
constexpr int kFwInputToForgetWeightsTensor = 2;
constexpr int kFwInputToCellWeightsTensor = 3;
constexpr int kFwInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kFwRecurrentToInputWeightsTensor = 5;  // Optional
constexpr int kFwRecurrentToForgetWeightsTensor = 6;
constexpr int kFwRecurrentToCellWeightsTensor = 7;
constexpr int kFwRecurrentToOutputWeightsTensor = 8;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kFwCellToInputWeightsTensor = 9;    // Optional
constexpr int kFwCellToForgetWeightsTensor = 10;  // Optional
constexpr int kFwCellToOutputWeightsTensor = 11;  // Optional

// Gates bias tensors of size {n_cell}
constexpr int kFwInputGateBiasTensor = 12;  // Optional
constexpr int kFwForgetGateBiasTensor = 13;
constexpr int kFwCellGateBiasTensor = 14;
constexpr int kFwOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kFwProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kFwProjectionBiasTensor = 17;  // Optional

// Backward LSTM cell tensors.
// Input weight tensors of size: {n_cell, n_input}
constexpr int kBwInputToInputWeightsTensor = 18;  // Optional
constexpr int kBwInputToForgetWeightsTensor = 19;
constexpr int kBwInputToCellWeightsTensor = 20;
constexpr int kBwInputToOutputWeightsTensor = 21;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kBwRecurrentToInputWeightsTensor = 22;  // Optional
constexpr int kBwRecurrentToForgetWeightsTensor = 23;
constexpr int kBwRecurrentToCellWeightsTensor = 24;
constexpr int kBwRecurrentToOutputWeightsTensor = 25;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kBwCellToInputWeightsTensor = 26;   // Optional
constexpr int kBwCellToForgetWeightsTensor = 27;  // Optional
constexpr int kBwCellToOutputWeightsTensor = 28;  // Optional

// Gates bias tensors of size {n_cell}
constexpr int kBwInputGateBiasTensor = 29;  // Optional
constexpr int kBwForgetGateBiasTensor = 30;
constexpr int kBwCellGateBiasTensor = 31;
constexpr int kBwOutputGateBiasTensor = 32;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kBwProjectionWeightsTensor = 33;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kBwProjectionBiasTensor = 34;  // Optional

// Stateful input tensors that are variables and will be modified by the Op.
// Activation state tensors of size {n_batch, n_output}
constexpr int kFwInputActivationStateTensor = 35;
// Cell state tensors of size {n_batch, n_cell}
constexpr int kFwInputCellStateTensor = 36;
// Activation state tensors of size {n_batch, n_output}
constexpr int kBwInputActivationStateTensor = 37;
// Cell state tensors of size {n_batch, n_cell}
constexpr int kBwInputCellStateTensor = 38;

// Used as auxiliary input and weights when stacking for
// tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); Used as
// input to the backward cell when stacking for tf.nn.static_bidirectional_rnn
// case (without cross links).
constexpr int kAuxInputTensor = 39;  // Optional
// Forward weights.
constexpr int kFwAuxInputToInputWeightsTensor = 40;   // Optional
constexpr int kFwAuxInputToForgetWeightsTensor = 41;  // Optional
constexpr int kFwAuxInputToCellWeightsTensor = 42;    // Optional
constexpr int kFwAuxInputToOutputWeightsTensor = 43;  // Optional
// Backward weights.
constexpr int kBwAuxInputToInputWeightsTensor = 44;   // Optional
constexpr int kBwAuxInputToForgetWeightsTensor = 45;  // Optional
constexpr int kBwAuxInputToCellWeightsTensor = 46;    // Optional
constexpr int kBwAuxInputToOutputWeightsTensor = 47;  // Optional

// Output tensors.
constexpr int kFwOutputTensor = 0;
constexpr int kBwOutputTensor = 1;  // Ignored if merge_outputs is set.

// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantize_weights.cc)

// Temporary tensors.
enum TemporaryTensor {
  // Scratch buffers for input, forget, etc. gates
  kFwScratchBuffer = 0,
  kBwScratchBuffer = 1,
  // Quantized tensors needed for the hybrid kernel.
  kInputQuantized = 2,
  kFwActivationStateQuantized = 3,
  kBwActivationStateQuantized = 4,
  kFwCellStateQuantized = 5,
  kBwCellStateQuantized = 6,
  kInputScalingFactors = 7,
  kAuxInputScalingFactors = 8,
  kOutputStateScalingFactors = 9,
  kProductScalingFactors = 10,
  kRecoveredCellWeights = 11,
  kAccumScratchBuffer = 12,
  kInputZeroPoints = 13,
  kAuxInputZeroPoints = 14,
  kOutputStateZeroPoints = 15,
  kFwRowSums = 16,
  kBwRowSums = 17,
  kAuxInputQuantized = 18,  // Optional, quantized tensor for auxiliary input.
  kNumTemporaryTensors = 19,
};

struct OpData {
  int scratch_tensor_index;
  bool compute_fw_row_sums = false;
  bool compute_bw_row_sums = false;
};
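
// A minimal sketch of the temporaries indexing convention used by Prepare()
// below: each TemporaryTensor enum value is an offset from
// OpData::scratch_tensor_index, which Init() obtains from
// TfLiteContext::AddTensors. For a given temporary `k` the wiring is:
//
//   node->temporaries->data[k] = op_data->scratch_tensor_index + k;
//   TfLiteTensor* tmp;
//   TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, k, &tmp));
//
// Prepare() repeats this pattern for every scratch tensor it allocates.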

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  context->AddTensors(context, kNumTemporaryTensors,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

// Check that the input tensor dimensions match each other.
TfLiteStatus CheckLstmTensorDimensionsAndTypes(
    TfLiteContext* context, TfLiteNode* node, int n_input, int n_output,
    int n_cell, int input_to_input_weights_tensor,
    int input_to_forget_weights_tensor, int input_to_cell_weights_tensor,
    int input_to_output_weights_tensor, int recurrent_to_input_weights_tensor,
    int recurrent_to_forget_weights_tensor,
    int recurrent_to_cell_weights_tensor,
    int recurrent_to_output_weights_tensor, int cell_to_input_weights_tensor,
    int cell_to_forget_weights_tensor, int cell_to_output_weights_tensor,
    int input_gate_bias_tensor, int forget_gate_bias_tensor,
    int cell_gate_bias_tensor, int output_gate_bias_tensor,
    int projection_weights_tensor, int projection_bias_tensor) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
      node->builtin_data);

  // Make sure the clipping parameters have valid values:
  //   == 0 means no clipping
  //    > 0 means clipping
  TF_LITE_ENSURE(context, params->cell_clip >= 0);
  TF_LITE_ENSURE(context, params->proj_clip >= 0);
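
  // For example (an illustrative note, not additional validation): with
  // cell_clip == 3.0 the evaluation routines clamp the updated cell state
  // elementwise to [-3.0, 3.0], while cell_clip == 0.0 leaves it unclamped.
  // The same convention applies to proj_clip and the projection output.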

  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, input_to_forget_weights_tensor,
                                 &input_to_forget_weights));
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
  TF_LITE_ENSURE(context, (input_to_forget_weights->type == kTfLiteFloat32) ||
                              (input_to_forget_weights->type == kTfLiteInt8) ||
                              (input_to_forget_weights->type == kTfLiteUInt8));

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, input_to_input_weights_tensor);
  if (input_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
    TF_LITE_ENSURE_TYPES_EQ(context, input_to_input_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, input_to_cell_weights_tensor,
                                 &input_to_cell_weights));
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
  TF_LITE_ENSURE_TYPES_EQ(context, input_to_cell_weights->type,
                          input_to_forget_weights->type);

  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, input_to_output_weights_tensor,
                                 &input_to_output_weights));
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
  TF_LITE_ENSURE_TYPES_EQ(context, input_to_output_weights->type,
                          input_to_forget_weights->type);

  const TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, recurrent_to_input_weights_tensor);
  if (recurrent_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
                      n_output);
    TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_input_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, recurrent_to_forget_weights_tensor,
                            &recurrent_to_forget_weights));
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                    n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
                    n_output);
  TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_forget_weights->type,
                          input_to_forget_weights->type);

  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, recurrent_to_cell_weights_tensor,
                            &recurrent_to_cell_weights));
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
                    n_output);
  TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_cell_weights->type,
                          input_to_forget_weights->type);

  // We make sure the input-gate's parameters are either both present (regular
  // LSTM) or not at all (CIFG-LSTM).
  const bool cifg_weights_all_or_none =
      ((input_to_input_weights != nullptr) &&
       (recurrent_to_input_weights != nullptr)) ||
      ((input_to_input_weights == nullptr) &&
       (recurrent_to_input_weights == nullptr));
  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);
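
  // Background note: in the CIFG ("coupled input and forget gate") variant
  // the input gate is not parameterized at all; the evaluation code derives
  // it from the forget gate as input_gate = 1 - forget_gate, which is why
  // both input-gate weight tensors must be absent together.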

  const TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, cell_to_input_weights_tensor);
  if (cell_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, cell_to_input_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, cell_to_forget_weights_tensor);
  if (cell_to_forget_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, cell_to_forget_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, cell_to_output_weights_tensor);
  if (cell_to_output_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, cell_to_output_weights->type,
                            input_to_forget_weights->type);
  }

  // Make sure the peephole weights are either all present or all absent
  // (except the input-gate peephole, which is also absent under CIFG).
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool peephole_weights_all_or_none =
      ((cell_to_input_weights != nullptr || use_cifg) &&
       (cell_to_forget_weights != nullptr) &&
       (cell_to_output_weights != nullptr)) ||
      ((cell_to_input_weights == nullptr) &&
       (cell_to_forget_weights == nullptr) &&
       (cell_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);

  // Make sure the input gate bias is present only when not a CIFG-LSTM.
  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, input_gate_bias_tensor);
  if (use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
  } else {
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32);
  }

  const TfLiteTensor* forget_gate_bias;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, forget_gate_bias_tensor, &forget_gate_bias));
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
  TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);

  const TfLiteTensor* cell_gate_bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, cell_gate_bias_tensor,
                                          &cell_gate_bias));
  TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
  TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32);

  const TfLiteTensor* output_gate_bias;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, output_gate_bias_tensor, &output_gate_bias));
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
  TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32);

  const TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, projection_weights_tensor);
  if (projection_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, projection_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, projection_bias_tensor);
  if (projection_bias != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
    TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32);
  }

  // Make sure the projection tensors are consistent:
  // 1) If the projection weights are not present, then the projection bias
  //    must not be present either.
  // 2) If the projection weights are present, then the projection bias is
  //    optional.
  // TODO(ghodrat): make sure this is correct.
  const bool projection_tensors_consistent =
      ((projection_weights != nullptr) || (projection_bias == nullptr));
  TF_LITE_ENSURE(context, projection_tensors_consistent == true);
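
  // The four cases the check above allows or rejects, spelled out:
  //
  //   projection_weights | projection_bias | result
  //   -------------------+-----------------+---------------------------
  //   present            | present         | OK
  //   present            | absent          | OK
  //   absent             | absent          | OK (no projection layer)
  //   absent             | present         | rejected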

  return kTfLiteOk;
}

TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell) {
  TF_LITE_ENSURE_OK(
      context,
      CheckLstmTensorDimensionsAndTypes(
          context, node, n_input, n_output, n_cell,
          kFwInputToInputWeightsTensor, kFwInputToForgetWeightsTensor,
          kFwInputToCellWeightsTensor, kFwInputToOutputWeightsTensor,
          kFwRecurrentToInputWeightsTensor, kFwRecurrentToForgetWeightsTensor,
          kFwRecurrentToCellWeightsTensor, kFwRecurrentToOutputWeightsTensor,
          kFwCellToInputWeightsTensor, kFwCellToForgetWeightsTensor,
          kFwCellToOutputWeightsTensor, kFwInputGateBiasTensor,
          kFwForgetGateBiasTensor, kFwCellGateBiasTensor,
          kFwOutputGateBiasTensor, kFwProjectionWeightsTensor,
          kFwProjectionBiasTensor));

  TF_LITE_ENSURE_OK(
      context,
      CheckLstmTensorDimensionsAndTypes(
          context, node, n_input, n_output, n_cell,
          kBwInputToInputWeightsTensor, kBwInputToForgetWeightsTensor,
          kBwInputToCellWeightsTensor, kBwInputToOutputWeightsTensor,
          kBwRecurrentToInputWeightsTensor, kBwRecurrentToForgetWeightsTensor,
          kBwRecurrentToCellWeightsTensor, kBwRecurrentToOutputWeightsTensor,
          kBwCellToInputWeightsTensor, kBwCellToForgetWeightsTensor,
          kBwCellToOutputWeightsTensor, kBwInputGateBiasTensor,
          kBwForgetGateBiasTensor, kBwCellGateBiasTensor,
          kBwOutputGateBiasTensor, kBwProjectionWeightsTensor,
          kBwProjectionBiasTensor));

  // Check if Forward and Backward tensors match along required dimensions.
  return kTfLiteOk;
}

// Resize the output and scratch tensors based on the sizes of the input
// tensors. Also check that the sizes of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
      node->builtin_data);

  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 48);
  TF_LITE_ENSURE_EQ(context, node->outputs->size,
                    params->merge_outputs ? 1 : 2);

  // Infer the batch size, sequence length, number of inputs, and number of
  // cells from the input tensors.
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
  const bool time_major = params->time_major;
  const int max_time = time_major ? input->dims->data[0] : input->dims->data[1];
  const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
  const int n_input = input->dims->data[2];
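
  // The two supported input layouts, illustrated with a hypothetical
  // sequence of 10 steps, batch of 4, and feature width 16:
  //
  //   time_major == true:   input->dims == {10, 4, 16}  // {max_time, n_batch, n_input}
  //   time_major == false:  input->dims == {4, 10, 16}  // {n_batch, max_time, n_input}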

  const TfLiteTensor* fw_input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwInputToOutputWeightsTensor,
                                 &fw_input_to_output_weights));
  const int n_fw_cell = fw_input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->data[1],
                    n_input);

  const TfLiteTensor* bw_input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwInputToOutputWeightsTensor,
                                 &bw_input_to_output_weights));
  const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1],
                    n_input);
  TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->type,
                    fw_input_to_output_weights->type);

  const TfLiteTensor* fw_recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kFwRecurrentToOutputWeightsTensor,
                            &fw_recurrent_to_output_weights));
  TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->data[0],
                    n_fw_cell);
  TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->type,
                    fw_input_to_output_weights->type);
  const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];

  const TfLiteTensor* bw_recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kBwRecurrentToOutputWeightsTensor,
                            &bw_recurrent_to_output_weights));
  TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0],
                    n_bw_cell);
  TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->type,
                    fw_input_to_output_weights->type);
  const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];

  // Check that the input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(
      context, CheckInputTensorDimensions(context, node, n_input, n_fw_output,
                                          n_fw_cell));

  // Get (optional) auxiliary inputs and weights.
  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);

  const bool aux_inputs_weights_all_or_none =
      ((fw_aux_input_to_cell_weights != nullptr) &&
       (fw_aux_input_to_forget_weights != nullptr) &&
       (fw_aux_input_to_output_weights != nullptr) &&
       (bw_aux_input_to_cell_weights != nullptr) &&
       (bw_aux_input_to_forget_weights != nullptr) &&
       (bw_aux_input_to_output_weights != nullptr)) ||
      ((fw_aux_input_to_cell_weights == nullptr) &&
       (fw_aux_input_to_forget_weights == nullptr) &&
       (fw_aux_input_to_output_weights == nullptr) &&
       (bw_aux_input_to_cell_weights == nullptr) &&
       (bw_aux_input_to_forget_weights == nullptr) &&
       (bw_aux_input_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, aux_inputs_weights_all_or_none);

  const bool has_aux_input = (fw_aux_input_to_forget_weights != nullptr);

  if (has_aux_input) {
    // Check that aux_input has the same dimensions (except last) as the input.
    TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
  }

  // Get pointers to the output, activation_state, and cell_state buffer
  // tensors.
  TfLiteTensor* fw_output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
  TfLiteTensor* fw_activation_state =
      GetVariableInput(context, node, kFwInputActivationStateTensor);
  TF_LITE_ENSURE(context, fw_activation_state != nullptr);
  TfLiteTensor* fw_cell_state =
      GetVariableInput(context, node, kFwInputCellStateTensor);
  TF_LITE_ENSURE(context, fw_cell_state != nullptr);

  // Check the shape of the input state tensors.
  // These tensors may be 1D or 2D; any shape is fine as long as the total
  // number of elements is correct.
  TF_LITE_ENSURE_EQ(context, NumElements(fw_activation_state),
                    n_batch * n_fw_output);
  TF_LITE_ENSURE_EQ(context, NumElements(fw_cell_state), n_batch * n_fw_cell);

  // Resize the output tensors.
  TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3);
  fw_output_size->data[0] = time_major ? max_time : n_batch;
  fw_output_size->data[1] = time_major ? n_batch : max_time;
  fw_output_size->data[2] =
      params->merge_outputs ? n_bw_output + n_fw_output : n_fw_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, fw_output, fw_output_size));

  // The weights are of consistent type, so it suffices to check one.
  const bool is_hybrid_op = IsHybridOp(input, fw_input_to_output_weights);
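  // Note: "hybrid" here means float32 activations combined with quantized
  // (int8/uint8) weights. The hybrid path quantizes activations on the fly,
  // which is why the extra quantization temporaries below are allocated only
  // in that case.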

  TfLiteIntArrayFree(node->temporaries);
  if (is_hybrid_op) {
    node->temporaries = TfLiteIntArrayCreate(
        has_aux_input ? kNumTemporaryTensors : kNumTemporaryTensors - 1);
  } else {
    node->temporaries = TfLiteIntArrayCreate(2);  // the two scratch buffers.
  }
  // Create a scratch buffer tensor.
  node->temporaries->data[kFwScratchBuffer] =
      op_data->scratch_tensor_index + kFwScratchBuffer;
  TfLiteTensor* fw_scratch_buffer;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kFwScratchBuffer,
                                              &fw_scratch_buffer));
  fw_scratch_buffer->type = input->type;
  fw_scratch_buffer->allocation_type = kTfLiteArenaRw;

  const TfLiteTensor* fw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
  const bool fw_use_cifg = (fw_input_to_input_weights == nullptr);
  if (has_aux_input && !fw_use_cifg) {
    TF_LITE_ENSURE_EQ(context, fw_aux_input_to_input_weights->dims->data[0],
                      fw_input_to_input_weights->dims->data[0]);
  }
  TfLiteIntArray* fw_scratch_buffer_size = TfLiteIntArrayCreate(2);
  fw_scratch_buffer_size->data[0] = n_batch;
  if (fw_use_cifg) {
    // Reserving space for Cell, Forget, Output gates
    fw_scratch_buffer_size->data[1] = n_fw_cell * 3;
  } else {
    // Reserving space for Input, Cell, Forget, Output gates
    fw_scratch_buffer_size->data[1] = n_fw_cell * 4;
  }
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_scratch_buffer,
                                                   fw_scratch_buffer_size));
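
  // A sketch of how the evaluation code consumes this buffer: it is carved
  // into one n_batch * n_fw_cell segment of gate pre-activations per gate,
  // in the order
  //
  //   [ input | cell | forget | output ]
  //
  // Under CIFG the input-gate segment is omitted (the input gate is derived
  // from the forget gate), which is why only three segments are reserved.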
  // Same for the backward cell.

  // Check that the input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(
      context, CheckInputTensorDimensions(context, node, n_input, n_bw_output,
                                          n_bw_cell));

  // Get pointers to the activation_state and cell_state buffer tensors.
  TfLiteTensor* bw_activation_state =
      GetVariableInput(context, node, kBwInputActivationStateTensor);
  TF_LITE_ENSURE(context, bw_activation_state != nullptr);
  TfLiteTensor* bw_cell_state =
      GetVariableInput(context, node, kBwInputCellStateTensor);
  TF_LITE_ENSURE(context, bw_cell_state != nullptr);

  // Resize the output tensors.
  if (!params->merge_outputs) {
    TfLiteTensor* bw_output;
    TF_LITE_ENSURE_OK(
        context, GetOutputSafe(context, node, kBwOutputTensor, &bw_output));
    TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
    bw_output_size->data[0] = time_major ? max_time : n_batch;
    bw_output_size->data[1] = time_major ? n_batch : max_time;
    bw_output_size->data[2] = n_bw_output;
    TF_LITE_ENSURE_OK(
        context, context->ResizeTensor(context, bw_output, bw_output_size));
  }

  // Check the shape of the input state tensors.
  // These tensors may be 1D or 2D; any shape is fine as long as the total
  // number of elements is correct.
  TF_LITE_ENSURE_EQ(context, NumElements(bw_activation_state),
                    n_batch * n_bw_output);
  TF_LITE_ENSURE_EQ(context, NumElements(bw_cell_state), n_batch * n_bw_cell);

  // Create a scratch buffer tensor.
  node->temporaries->data[kBwScratchBuffer] =
      op_data->scratch_tensor_index + kBwScratchBuffer;
  TfLiteTensor* bw_scratch_buffer;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kBwScratchBuffer,
                                              &bw_scratch_buffer));
  bw_scratch_buffer->type = input->type;
  bw_scratch_buffer->allocation_type = kTfLiteArenaRw;

  const TfLiteTensor* bw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
  const bool bw_use_cifg = (bw_input_to_input_weights == nullptr);
  if (has_aux_input && !bw_use_cifg) {
    TF_LITE_ENSURE_EQ(context, bw_aux_input_to_input_weights->dims->data[0],
                      bw_input_to_input_weights->dims->data[0]);
  }
  TfLiteIntArray* bw_scratch_buffer_size = TfLiteIntArrayCreate(2);
  bw_scratch_buffer_size->data[0] = n_batch;
  if (bw_use_cifg) {
    // Reserving space for Cell, Forget, Output gates
    bw_scratch_buffer_size->data[1] = n_bw_cell * 3;
  } else {
    // Reserving space for Input, Cell, Forget, Output gates
    bw_scratch_buffer_size->data[1] = n_bw_cell * 4;
  }
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer,
                                                   bw_scratch_buffer_size));
  if (is_hybrid_op) {
    // Compute the row sums for cached zero_point offset calculation.
    op_data->compute_fw_row_sums = true;
    op_data->compute_bw_row_sums = true;
    // Allocate temporary tensors to store quantized values of input, aux_input
    // (if present), activation_state and cell_state tensors.
    node->temporaries->data[kInputQuantized] =
        op_data->scratch_tensor_index + kInputQuantized;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
                                                &input_quantized));
    input_quantized->type = fw_input_to_output_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, input_quantized,
                                              input_quantized_size));
    }

    node->temporaries->data[kFwActivationStateQuantized] =
        op_data->scratch_tensor_index + kFwActivationStateQuantized;
    TfLiteTensor* fw_activation_state_quantized;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kFwActivationStateQuantized,
                                  &fw_activation_state_quantized));
    fw_activation_state_quantized->type = fw_input_to_output_weights->type;
    fw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_activation_state_quantized->dims,
                             fw_activation_state->dims)) {
      TfLiteIntArray* fw_activation_state_quantized_size =
          TfLiteIntArrayCopy(fw_activation_state->dims);
      TF_LITE_ENSURE_OK(
          context,
          context->ResizeTensor(context, fw_activation_state_quantized,
                                fw_activation_state_quantized_size));
    }
    node->temporaries->data[kBwActivationStateQuantized] =
        op_data->scratch_tensor_index + kBwActivationStateQuantized;
    TfLiteTensor* bw_activation_state_quantized;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kBwActivationStateQuantized,
                                  &bw_activation_state_quantized));
    bw_activation_state_quantized->type = fw_input_to_output_weights->type;
    bw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(bw_activation_state_quantized->dims,
                             bw_activation_state->dims)) {
      TfLiteIntArray* bw_activation_state_quantized_size =
          TfLiteIntArrayCopy(bw_activation_state->dims);
      TF_LITE_ENSURE_OK(
          context,
          context->ResizeTensor(context, bw_activation_state_quantized,
                                bw_activation_state_quantized_size));
    }
    node->temporaries->data[kFwCellStateQuantized] =
        op_data->scratch_tensor_index + kFwCellStateQuantized;
    TfLiteTensor* fw_cell_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kFwCellStateQuantized,
                                       &fw_cell_state_quantized));
    fw_cell_state_quantized->type = fw_input_to_output_weights->type;
    fw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_cell_state_quantized->dims,
                             fw_cell_state->dims)) {
      TfLiteIntArray* fw_cell_state_quantized_size =
          TfLiteIntArrayCopy(fw_cell_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, fw_cell_state_quantized,
                                              fw_cell_state_quantized_size));
    }
    node->temporaries->data[kBwCellStateQuantized] =
        op_data->scratch_tensor_index + kBwCellStateQuantized;
    TfLiteTensor* bw_cell_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kBwCellStateQuantized,
                                       &bw_cell_state_quantized));
    bw_cell_state_quantized->type = fw_input_to_output_weights->type;
    bw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(bw_cell_state_quantized->dims,
                             bw_cell_state->dims)) {
      TfLiteIntArray* bw_cell_state_quantized_size =
          TfLiteIntArrayCopy(bw_cell_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, bw_cell_state_quantized,
                                              bw_cell_state_quantized_size));
    }

    // Allocate temporary tensors to store the scaling factors and the product
    // scaling factors. The latter is convenience storage that lets us quantize
    // a vector once (which produces the scaling factors) and multiply it with
    // different matrices (which requires multiplying each vector scaling
    // factor with the matrix scaling factor).
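    // For instance (an illustrative sketch of the hybrid arithmetic): with a
    // batch row x quantized as x_q with per-row scale s_x, and a weight
    // matrix W quantized as W_q with scale s_w, the float result is
    // recovered as
    //
    //   y = (W_q * x_q) * (s_x * s_w)
    //
    // so prod_scaling_factors caches the per-row products s_x * s_w.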
    node->temporaries->data[kInputScalingFactors] =
        op_data->scratch_tensor_index + kInputScalingFactors;
    TfLiteTensor* input_sf;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, kInputScalingFactors, &input_sf));
    input_sf->type = kTfLiteFloat32;
    input_sf->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {n_batch};
    if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) {
      TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1);
      input_sf_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, input_sf, input_sf_size));
    }
    node->temporaries->data[kAuxInputScalingFactors] =
        op_data->scratch_tensor_index + kAuxInputScalingFactors;
    TfLiteTensor* aux_input_sf;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kAuxInputScalingFactors,
                                       &aux_input_sf));
    aux_input_sf->type = kTfLiteFloat32;
    aux_input_sf->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(aux_input_sf->dims, 1, scaling_dims)) {
      TfLiteIntArray* aux_input_sf_size = TfLiteIntArrayCreate(1);
      aux_input_sf_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, aux_input_sf,
                                                       aux_input_sf_size));
    }
    node->temporaries->data[kOutputStateScalingFactors] =
        op_data->scratch_tensor_index + kOutputStateScalingFactors;
    TfLiteTensor* output_state_sf;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kOutputStateScalingFactors,
                                  &output_state_sf));
    output_state_sf->type = kTfLiteFloat32;
    output_state_sf->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
      TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1);
      output_state_sf_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, output_state_sf,
                                              output_state_sf_size));
    }
    node->temporaries->data[kProductScalingFactors] =
        op_data->scratch_tensor_index + kProductScalingFactors;
    TfLiteTensor* prod_scaling_factors;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kProductScalingFactors,
                                       &prod_scaling_factors));
    prod_scaling_factors->type = kTfLiteFloat32;
    prod_scaling_factors->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
                                   scaling_dims)) {
      TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
      prod_scaling_factors_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, prod_scaling_factors,
                                              prod_scaling_factors_size));
    }

    // Allocate a temporary tensor to store the recovered cell weights. Since
    // these are used as diagonal matrices, only n_cell values need to be
    // stored.
    node->temporaries->data[kRecoveredCellWeights] =
        op_data->scratch_tensor_index + kRecoveredCellWeights;
    TfLiteTensor* recovered_cell_weights;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kRecoveredCellWeights,
                                       &recovered_cell_weights));
    recovered_cell_weights->type = kTfLiteFloat32;
    recovered_cell_weights->allocation_type = kTfLiteArenaRw;
    int recovered_cell_dims[1] = {n_fw_cell};
    if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1,
                                   recovered_cell_dims)) {
      TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
      recovered_cell_weights_size->data[0] = n_fw_cell;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, recovered_cell_weights,
                                              recovered_cell_weights_size));
    }

    // Allocate a temporary tensor to store the accumulated int32 values.
    node->temporaries->data[kAccumScratchBuffer] =
        op_data->scratch_tensor_index + kAccumScratchBuffer;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, kAccumScratchBuffer, &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int n_cell = std::max(n_fw_cell, n_bw_cell);
    if (has_aux_input) {
      n_cell = std::max(n_cell, fw_aux_input_to_output_weights->dims->data[0]);
      n_cell = std::max(n_cell, bw_aux_input_to_output_weights->dims->data[0]);
    }
    int accum_scratch_dims[2] = {n_cell, n_batch};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
      accum_size->data[0] = n_cell;
      accum_size->data[1] = n_batch;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, accum_scratch, accum_size));
    }

    // Allocate temporary tensors for storing zero-points.
    node->temporaries->data[kInputZeroPoints] =
        op_data->scratch_tensor_index + kInputZeroPoints;
    TfLiteTensor* input_zp;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp));
    input_zp->type = kTfLiteFloat32;
    input_zp->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
      TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1);
      input_zp_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, input_zp, input_zp_size));
    }
    node->temporaries->data[kAuxInputZeroPoints] =
        op_data->scratch_tensor_index + kAuxInputZeroPoints;
    TfLiteTensor* aux_input_zp;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, kAuxInputZeroPoints, &aux_input_zp));
    aux_input_zp->type = kTfLiteFloat32;
    aux_input_zp->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(aux_input_zp->dims, 1, scaling_dims)) {
      TfLiteIntArray* aux_input_zp_size = TfLiteIntArrayCreate(1);
      aux_input_zp_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, aux_input_zp,
                                                       aux_input_zp_size));
    }
    node->temporaries->data[kOutputStateZeroPoints] =
        op_data->scratch_tensor_index + kOutputStateZeroPoints;
    TfLiteTensor* output_state_zp;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kOutputStateZeroPoints,
                                       &output_state_zp));
    output_state_zp->type = kTfLiteFloat32;
    output_state_zp->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
      TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1);
      output_state_zp_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, output_state_zp,
                                              output_state_zp_size));
    }

    // Allocate temporary tensors for caching row sums for hybrid zero-point
    // calculations.
    int fw_row_sums_rows = fw_use_cifg ? 6 : 8;
    if (has_aux_input) {
      fw_row_sums_rows += fw_use_cifg ? 3 : 4;
    }
    const TfLiteTensor* fw_projection_weights =
        GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor);
    if (fw_projection_weights != nullptr) {
      fw_row_sums_rows += ceil(static_cast<float>(n_fw_output) / n_fw_cell);
    }
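    // An illustrative count behind these numbers (non-CIFG, no aux input):
    // 4 input-to-gate matrices plus 4 recurrent-to-gate matrices need 8
    // row-sum rows; CIFG drops the input-gate matrix from each group (6);
    // aux-input weights add 4 more (3 under CIFG). Each storage row holds
    // n_fw_cell sums, so the {n_fw_output, n_fw_cell} projection matrix
    // needs ceil(n_fw_output / n_fw_cell) extra rows for its n_fw_output
    // sums.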
    node->temporaries->data[kFwRowSums] =
        op_data->scratch_tensor_index + kFwRowSums;
    TfLiteTensor* fw_row_sums;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
    fw_row_sums->type = kTfLiteInt32;
    fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int fw_row_sums_dims[2] = {fw_row_sums_rows, n_fw_cell};
    if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) {
      TfLiteIntArray* fw_hybrid_scratch_size = TfLiteIntArrayCreate(2);
      fw_hybrid_scratch_size->data[0] = fw_row_sums_dims[0];
      fw_hybrid_scratch_size->data[1] = fw_row_sums_dims[1];
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, fw_row_sums,
                                              fw_hybrid_scratch_size));
    }

    int bw_row_sums_rows = bw_use_cifg ? 6 : 8;
    if (has_aux_input) {
      bw_row_sums_rows += bw_use_cifg ? 3 : 4;
    }
    const TfLiteTensor* bw_projection_weights =
        GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor);
    if (bw_projection_weights != nullptr) {
      bw_row_sums_rows += ceil(static_cast<float>(n_bw_output) / n_bw_cell);
    }
    node->temporaries->data[kBwRowSums] =
        op_data->scratch_tensor_index + kBwRowSums;
    TfLiteTensor* bw_row_sums;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
    bw_row_sums->type = kTfLiteInt32;
    bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int bw_row_sums_dims[2] = {bw_row_sums_rows, n_bw_cell};
    if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) {
      TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2);
      bw_row_sums_size->data[0] = bw_row_sums_dims[0];
      bw_row_sums_size->data[1] = bw_row_sums_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums,
                                                       bw_row_sums_size));
    }

    // Only allocate a temporary tensor for quantized auxiliary input if we are
    // actually going to use it.
    if (has_aux_input) {
      node->temporaries->data[kAuxInputQuantized] =
          op_data->scratch_tensor_index + kAuxInputQuantized;
      TfLiteTensor* aux_input_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kAuxInputQuantized,
                                         &aux_input_quantized));
      aux_input_quantized->type = fw_input_to_output_weights->type;
      aux_input_quantized->allocation_type = kTfLiteArenaRw;
      if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
        TfLiteIntArray* aux_input_quantized_size =
            TfLiteIntArrayCopy(aux_input->dims);
        TF_LITE_ENSURE_OK(context,
                          context->ResizeTensor(context, aux_input_quantized,
                                                aux_input_quantized_size));
      }
    }
  }
  return kTfLiteOk;
}

// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
      node->builtin_data);
  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
  // Input tensor.
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));

  // Tensors for the forward cell.
  const TfLiteTensor* fw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
  const TfLiteTensor* fw_input_to_forget_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwInputToForgetWeightsTensor,
                                 &fw_input_to_forget_weights));
  const TfLiteTensor* fw_input_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwInputToCellWeightsTensor,
                                 &fw_input_to_cell_weights));
  const TfLiteTensor* fw_input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwInputToOutputWeightsTensor,
                                 &fw_input_to_output_weights));

  const TfLiteTensor* fw_recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kFwRecurrentToInputWeightsTensor);
  const TfLiteTensor* fw_recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kFwRecurrentToForgetWeightsTensor,
                            &fw_recurrent_to_forget_weights));
  const TfLiteTensor* fw_recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwRecurrentToCellWeightsTensor,
                                 &fw_recurrent_to_cell_weights));
  const TfLiteTensor* fw_recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kFwRecurrentToOutputWeightsTensor,
                            &fw_recurrent_to_output_weights));

  const TfLiteTensor* fw_cell_to_input_weights =
      GetOptionalInputTensor(context, node, kFwCellToInputWeightsTensor);
  const TfLiteTensor* fw_cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kFwCellToForgetWeightsTensor);
  const TfLiteTensor* fw_cell_to_output_weights =
      GetOptionalInputTensor(context, node, kFwCellToOutputWeightsTensor);

  const TfLiteTensor* fw_input_gate_bias =
      GetOptionalInputTensor(context, node, kFwInputGateBiasTensor);
  const TfLiteTensor* fw_forget_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwForgetGateBiasTensor,
                                 &fw_forget_gate_bias));
  const TfLiteTensor* fw_cell_gate_bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwCellGateBiasTensor,
                                          &fw_cell_gate_bias));
  const TfLiteTensor* fw_output_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwOutputGateBiasTensor,
                                 &fw_output_gate_bias));

  const TfLiteTensor* fw_projection_weights =
      GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor);
  const TfLiteTensor* fw_projection_bias =
      GetOptionalInputTensor(context, node, kFwProjectionBiasTensor);

  TfLiteTensor* fw_activation_state =
      GetVariableInput(context, node, kFwInputActivationStateTensor);
  TFLITE_DCHECK(fw_activation_state != nullptr);
  TfLiteTensor* fw_cell_state =
      GetVariableInput(context, node, kFwInputCellStateTensor);
  TFLITE_DCHECK(fw_cell_state != nullptr);
  TfLiteTensor* fw_output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));

  // Tensors for the backward cell.
  const TfLiteTensor* bw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
  const TfLiteTensor* bw_input_to_forget_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwInputToForgetWeightsTensor,
                                 &bw_input_to_forget_weights));
  const TfLiteTensor* bw_input_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwInputToCellWeightsTensor,
                                 &bw_input_to_cell_weights));
  const TfLiteTensor* bw_input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwInputToOutputWeightsTensor,
                                 &bw_input_to_output_weights));

  const TfLiteTensor* bw_recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kBwRecurrentToInputWeightsTensor);
  const TfLiteTensor* bw_recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kBwRecurrentToForgetWeightsTensor,
                            &bw_recurrent_to_forget_weights));
  const TfLiteTensor* bw_recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwRecurrentToCellWeightsTensor,
                                 &bw_recurrent_to_cell_weights));
  const TfLiteTensor* bw_recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kBwRecurrentToOutputWeightsTensor,
                            &bw_recurrent_to_output_weights));

  const TfLiteTensor* bw_cell_to_input_weights =
      GetOptionalInputTensor(context, node, kBwCellToInputWeightsTensor);
  const TfLiteTensor* bw_cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kBwCellToForgetWeightsTensor);
  const TfLiteTensor* bw_cell_to_output_weights =
      GetOptionalInputTensor(context, node, kBwCellToOutputWeightsTensor);

  const TfLiteTensor* bw_input_gate_bias =
      GetOptionalInputTensor(context, node, kBwInputGateBiasTensor);
  const TfLiteTensor* bw_forget_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwForgetGateBiasTensor,
                                 &bw_forget_gate_bias));
  const TfLiteTensor* bw_cell_gate_bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwCellGateBiasTensor,
                                          &bw_cell_gate_bias));
  const TfLiteTensor* bw_output_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwOutputGateBiasTensor,
                                 &bw_output_gate_bias));

  const TfLiteTensor* bw_projection_weights =
      GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor);
  const TfLiteTensor* bw_projection_bias =
      GetOptionalInputTensor(context, node, kBwProjectionBiasTensor);

  // State tensors.
  TfLiteTensor* bw_activation_state =
      GetVariableInput(context, node, kBwInputActivationStateTensor);
  TFLITE_DCHECK(bw_activation_state != nullptr);
  TfLiteTensor* bw_cell_state =
      GetVariableInput(context, node, kBwInputCellStateTensor);
  TFLITE_DCHECK(bw_cell_state != nullptr);
  TfLiteTensor* bw_output = params->merge_outputs
                                ? nullptr
                                : GetOutput(context, node, kBwOutputTensor);

  // Temporary tensors.
  TfLiteTensor* fw_scratch_buffer;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kFwScratchBuffer,
                                              &fw_scratch_buffer));
  TfLiteTensor* bw_scratch_buffer;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kBwScratchBuffer,
                                              &bw_scratch_buffer));

  // (Optional) auxiliary inputs.
  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);

  const bool has_previous_bw_output = (aux_input != nullptr);
  const bool use_aux_input = (fw_aux_input_to_forget_weights != nullptr);

  // Populate a TfLiteLSTMParams struct for the evaluation functions.
  TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip,
                                  params->proj_clip, kTfLiteLSTMFullKernel,
                                  params->asymmetric_quantize_inputs};

  const int bw_output_offset =
      params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0;
  const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output;

  const bool time_major = params->time_major;

  // We want to cover the following cases:
  //
  // If not stacking (not connected after other bidi lstms):
  //   both fw & bw will just use `input`; aux_input will be null.
  //
  // If stacking with cross_links, TensorFlow equivalent
  // (tf.contrib.rnn.stack_bidirectional_rnn):
  //   both fw & bw will use `input`, but aux_input will be non-null.
  //   Note that this case works whether or not the op is connected after
  //   other bidi lstms.
  //
  // If stacking without cross_links, but connected after other bidi lstms,
  // TensorFlow equivalent (tf.nn.static_bidirectional_rnn):
  //   fw will use `input`, bw will use aux_input, and the `real aux_input`
  //   will be null.
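  //
  // A compact restatement of the dispatch implemented just below:
  //
  //   use_aux_input | has_previous_bw_output | bw_input  | real_aux_input
  //   --------------+------------------------+-----------+---------------
  //   false         | false                  | input     | null
  //   true          | (any)                  | input     | aux_input
  //   false         | true                   | aux_input | null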
1168
1169 const bool non_stacking_mode = !use_aux_input && has_previous_bw_output;
1170 const TfLiteTensor* bw_input = non_stacking_mode ? aux_input : input;
1171 const TfLiteTensor* real_aux_input = non_stacking_mode ? nullptr : aux_input;
1172
1173 switch (fw_input_to_output_weights->type) {
1174 case kTfLiteFloat32: {
1175 TfLiteStatus fw_pass_status = lstm_eval::EvalFloat(
1176 input, fw_input_to_input_weights, fw_input_to_forget_weights,
1177 fw_input_to_cell_weights, fw_input_to_output_weights,
1178 fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
1179 fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
1180 fw_cell_to_input_weights, fw_cell_to_forget_weights,
1181 fw_cell_to_output_weights,
1182 /*input_layer_norm_coefficients=*/nullptr,
1183 /*forget_layer_norm_coefficients=*/nullptr,
1184 /*cell_layer_norm_coefficients=*/nullptr,
1185 /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
1186 fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
1187 fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
1188 fw_input_gate_bias, fw_forget_gate_bias, fw_cell_gate_bias,
1189 fw_output_gate_bias, fw_projection_weights, fw_projection_bias,
1190 &lstm_params,
1191 /*forward_sequence=*/true, time_major, /*output_offset=*/0,
1192 fw_scratch_buffer, fw_activation_state, fw_cell_state, fw_output);
1193 TF_LITE_ENSURE_OK(context, fw_pass_status);
1194
      TfLiteStatus bw_pass_status = lstm_eval::EvalFloat(
          bw_input, bw_input_to_input_weights, bw_input_to_forget_weights,
          bw_input_to_cell_weights, bw_input_to_output_weights,
          bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
          bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
          bw_cell_to_input_weights, bw_cell_to_forget_weights,
          bw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
          bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
          bw_input_gate_bias, bw_forget_gate_bias, bw_cell_gate_bias,
          bw_output_gate_bias, bw_projection_weights, bw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/false, time_major, bw_output_offset,
          bw_scratch_buffer, bw_activation_state, bw_cell_state,
          actual_bw_output);
      TF_LITE_ENSURE_OK(context, bw_pass_status);
      return kTfLiteOk;
    }
    case kTfLiteUInt8:
    case kTfLiteInt8: {
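      // Hybrid path: the weights stay quantized; the float inputs and
      // states are quantized into the temporaries below on each invocation.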
      TfLiteTensor* input_quantized;
      TF_LITE_ENSURE_OK(
          context,
          GetTemporarySafe(context, node, kInputQuantized, &input_quantized));
      TfLiteTensor* fw_activation_state_quantized;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kFwActivationStateQuantized,
                                    &fw_activation_state_quantized));
      TfLiteTensor* bw_activation_state_quantized;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kBwActivationStateQuantized,
                                    &bw_activation_state_quantized));
      TfLiteTensor* fw_cell_state_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kFwCellStateQuantized,
                                         &fw_cell_state_quantized));
      TfLiteTensor* bw_cell_state_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kBwCellStateQuantized,
                                         &bw_cell_state_quantized));
      TfLiteTensor* prod_scaling_factors;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kProductScalingFactors,
                                         &prod_scaling_factors));
      TfLiteTensor* recovered_cell_weights;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kRecoveredCellWeights,
                                         &recovered_cell_weights));
      TfLiteTensor* aux_input_quantized =
          use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
                        : nullptr;
      TfLiteTensor* accum_scratch;
      TF_LITE_ENSURE_OK(
          context,
          GetTemporarySafe(context, node, kAccumScratchBuffer, &accum_scratch));
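      // Weight row sums are cached across invocations (they feed the
      // asymmetric-quantization correction) and are only recomputed when the
      // compute_*_row_sums flags are set.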
      TfLiteTensor* fw_row_sums;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
      TfLiteTensor* bw_row_sums;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
      const int fw_row_sums_size = fw_row_sums->dims->data[0];
      const int bw_row_sums_size = bw_row_sums->dims->data[0];
      TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
          input, fw_input_to_input_weights,
          /*input_to_input_weights_ledger=*/nullptr,
          fw_input_to_forget_weights,
          /*input_to_forget_weights_ledger=*/nullptr,
          fw_input_to_cell_weights,
          /*input_to_cell_weights_ledger=*/nullptr,
          fw_input_to_output_weights,
          /*input_to_output_weights_ledger=*/nullptr,
          fw_recurrent_to_input_weights,
          /*recurrent_to_input_weights_ledger=*/nullptr,
          fw_recurrent_to_forget_weights,
          /*recurrent_to_forget_weights_ledger=*/nullptr,
          fw_recurrent_to_cell_weights,
          /*recurrent_to_cell_weights_ledger=*/nullptr,
          fw_recurrent_to_output_weights,
          /*recurrent_to_output_weights_ledger=*/nullptr,
          fw_cell_to_input_weights, fw_cell_to_forget_weights,
          fw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
          fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
          fw_input_gate_bias, fw_forget_gate_bias, fw_cell_gate_bias,
          fw_output_gate_bias, fw_projection_weights,
          /*projection_weights_ledger=*/nullptr, fw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
          fw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),
          GetTemporary(context, node, kAuxInputScalingFactors),
          GetTemporary(context, node, kOutputStateScalingFactors),
          prod_scaling_factors, recovered_cell_weights, input_quantized,
          aux_input_quantized, fw_activation_state_quantized,
          fw_cell_state_quantized, fw_activation_state, fw_cell_state,
          accum_scratch, fw_output,
          GetTemporary(context, node, kInputZeroPoints),
          GetTemporary(context, node, kAuxInputZeroPoints),
          GetTemporary(context, node, kOutputStateZeroPoints), fw_row_sums,
          fw_row_sums_size, &op_data->compute_fw_row_sums,
          CpuBackendContext::GetFromContext(context));
      TF_LITE_ENSURE_OK(context, fw_pass_status);

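      // The backward hybrid pass mirrors the forward one: it consumes
      // bw_input and writes into actual_bw_output at bw_output_offset.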
      TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
          bw_input, bw_input_to_input_weights,
          /*input_to_input_weights_ledger=*/nullptr,
          bw_input_to_forget_weights,
          /*input_to_forget_weights_ledger=*/nullptr,
          bw_input_to_cell_weights,
          /*input_to_cell_weights_ledger=*/nullptr,
          bw_input_to_output_weights,
          /*input_to_output_weights_ledger=*/nullptr,
          bw_recurrent_to_input_weights,
          /*recurrent_to_input_weights_ledger=*/nullptr,
          bw_recurrent_to_forget_weights,
          /*recurrent_to_forget_weights_ledger=*/nullptr,
          bw_recurrent_to_cell_weights,
          /*recurrent_to_cell_weights_ledger=*/nullptr,
          bw_recurrent_to_output_weights,
          /*recurrent_to_output_weights_ledger=*/nullptr,
          bw_cell_to_input_weights, bw_cell_to_forget_weights,
          bw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
          bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
          bw_input_gate_bias, bw_forget_gate_bias, bw_cell_gate_bias,
          bw_output_gate_bias, bw_projection_weights,
          /*projection_weights_ledger=*/nullptr, bw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
          bw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),
          GetTemporary(context, node, kAuxInputScalingFactors),
          GetTemporary(context, node, kOutputStateScalingFactors),
          prod_scaling_factors, recovered_cell_weights, input_quantized,
          aux_input_quantized, bw_activation_state_quantized,
          bw_cell_state_quantized, bw_activation_state, bw_cell_state,
          accum_scratch, actual_bw_output,
          GetTemporary(context, node, kInputZeroPoints),
          GetTemporary(context, node, kAuxInputZeroPoints),
          GetTemporary(context, node, kOutputStateZeroPoints), bw_row_sums,
          bw_row_sums_size, &op_data->compute_bw_row_sums,
          CpuBackendContext::GetFromContext(context));
      TF_LITE_ENSURE_OK(context, bw_pass_status);
      return kTfLiteOk;
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s is not currently supported.",
                         TfLiteTypeGetName(fw_input_to_output_weights->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace bidirectional_sequence_lstm

TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM() {
  static TfLiteRegistration r = {
      bidirectional_sequence_lstm::Init, bidirectional_sequence_lstm::Free,
      bidirectional_sequence_lstm::Prepare, bidirectional_sequence_lstm::Eval};
  return &r;
}
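
// A minimal usage sketch (illustrative only; `model` is assumed to be a
// tflite::FlatBufferModel loaded elsewhere). Models containing this op need
// no manual registration: BuiltinOpResolver already includes the
// registration above.
//
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   std::unique_ptr<tflite::Interpreter> interpreter;
//   tflite::InterpreterBuilder(*model, resolver)(&interpreter);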

}  // namespace builtin
}  // namespace ops
}  // namespace tflite