1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/lstm_eval.h"
16 
17 #include <cstdint>
18 
19 #include "tensorflow/lite/c/c_api_internal.h"
20 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
21 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
22 #include "tensorflow/lite/kernels/op_macros.h"
23 
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace lstm_eval {
28 
29 namespace {
30 
31 // Small float added for numerical stability when computing the standard
32 // deviation in the layer norm lstm.
33 const float kLayerNormEpsilon = 1e-8;
34 
35 // Performs an LSTM batch inference step for input specified by input_ptr_batch.
36 // The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
37 // biases (*_bias_ptr), and buffers (*_scratch), along with additional
38 // parameters:
39 //  - params: various LSTM params including activation, clipping, etc.,
40 //  - n_batch: size of batch,
41 //  - n_cell: number of cells (or units),
42 //  - n_input: the input size,
43 //  - n_aux_input: the auxiliary input size.
44 //  - n_output: the output size.
45 //  - output_batch_leading_dim: the leading dimension of the output buffer.
46 //
47 // LSTM weights:
48 // Input weights of size 'n_cell * n_input':
49 //   input_to_input_weights            - optional (can be nullptr)
50 //   input_to_forget_weights
51 //   input_to_cell_weights
52 //   input_to_output_weights
53 // Auxiliary input weights of size 'n_cell * n_aux_input':
54 //   aux_input_to_input_weights        - optional
55 //   aux_input_to_forget_weights       - optional
56 //   aux_input_to_cell_weights         - optional
57 //   aux_input_to_output_weights       - optional
58 // Recurrent weights of size 'n_cell * n_output':
59 //   recurrent_to_input_weights        - optional
60 //   recurrent_to_forget_weights
61 //   recurrent_to_cell_weights
62 //   recurrent_to_output_weights
63 // Peephole weights of size 'n_cell', representing diagonal matrices.
64 //   cell_to_input_weights             - optional
65 //   cell_to_forget_weights            - optional
66 //   cell_to_output_weights            - optional
67 // Projection weights of size 'n_output * n_cell'
68 //   projection_weights_ptr            - optional
69 // Gate biases of size 'n_cell':
70 //   input_gate_bias_ptr               - optional
71 //   forget_gate_bias_ptr
72 //   cell_bias_ptr
73 //   output_gate_bias_ptr
74 //
75 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
76 //   input_layer_norm_coefficients_ptr  - optional
77 //   forget_layer_norm_coefficients_ptr - optional
78 //   cell_layer_norm_coefficients_ptr   - optional
79 //   output_layer_norm_coefficients_ptr - optional
80 //
81 // The pointers to the cell and output state and the output are updated.
82 //
83 // The pointers with the suffix "_batch" point to data aligned in batch_major
84 // order, and each step processes batch_size many inputs from input_ptr_batch,
85 // and updates batch_size many cell and output states.
86 //
87 // The output_batch_leading_dim is output.shape[-1], i.e. the last (innermost)
88 // dimension of the output tensor, and in most cases is equal to n_output. It
89 // differs from n_output when the LSTM output is written into a slice of a
90 // larger output tensor, e.g. for bidirectional LSTMs with merge_outputs. In
91 // that case the batched operations cannot be used, since they assume that the
92 // batched outputs are contiguous, so we manually loop over the batches.
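//
// For reference, an informal sketch (not normative) of the per-timestep update
// this function implements for each batch entry, in the notation of the
// parameters above:
//   i = sigmoid(W_xi*x + W_ai*aux + W_hi*h_prev + w_ci.*c_prev + b_i)
//   f = sigmoid(W_xf*x + W_af*aux + W_hf*h_prev + w_cf.*c_prev + b_f)
//   c = f .* c_prev + i .* g(W_xc*x + W_ac*aux + W_hc*h_prev + b_c)
//   o = sigmoid(W_xo*x + W_ao*aux + W_ho*h_prev + w_co.*c + b_o)
//   h = o .* g(c), optionally followed by h = clip(W_proj*h + b_proj)
// where g is params->activation and .* is the elementwise product. Aux-input
// and peephole terms apply only when the corresponding weights are present,
// CIFG replaces i with (1 - f), layer norm (when enabled) normalizes each gate
// pre-activation and scales it by the layer-norm coefficients before the gate
// bias is added, and c and the projected output are optionally clipped.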
93 inline void LstmStepWithAuxInput(
94     const float* input_ptr_batch, const float* input_to_input_weights_ptr,
95     const float* input_to_forget_weights_ptr,
96     const float* input_to_cell_weights_ptr,
97     const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
98     const float* aux_input_to_input_weights_ptr,
99     const float* aux_input_to_forget_weights_ptr,
100     const float* aux_input_to_cell_weights_ptr,
101     const float* aux_input_to_output_weights_ptr,
102     const float* recurrent_to_input_weights_ptr,
103     const float* recurrent_to_forget_weights_ptr,
104     const float* recurrent_to_cell_weights_ptr,
105     const float* recurrent_to_output_weights_ptr,
106     const float* cell_to_input_weights_ptr,
107     const float* cell_to_forget_weights_ptr,
108     const float* cell_to_output_weights_ptr,
109     const float* input_layer_norm_coefficients_ptr,
110     const float* forget_layer_norm_coefficients_ptr,
111     const float* cell_layer_norm_coefficients_ptr,
112     const float* output_layer_norm_coefficients_ptr,
113     const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
114     const float* cell_bias_ptr, const float* output_gate_bias_ptr,
115     const float* projection_weights_ptr, const float* projection_bias_ptr,
116     const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
117     int n_aux_input, int n_output, int output_batch_leading_dim,
118     float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
119     float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
120     float* output_ptr_batch) {
121   // Since we have already checked that weights are all there or none, we can
122   // check the existence of just one of them.
123   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
124   const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
125   const bool is_layer_norm_lstm =
126       (forget_layer_norm_coefficients_ptr != nullptr);
127 
128   // Initialize scratch buffers with the gate biases for a regular LSTM, or
129   // with zeros for a layer norm LSTM (biases are added after normalization).
130   if (is_layer_norm_lstm) {
131     if (!use_cifg) {
132       tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
133     }
134     tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
135     tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
136     tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
137   } else {
138     if (!use_cifg) {
139       tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
140                                             n_batch, input_gate_scratch);
141     }
142     tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
143                                           forget_gate_scratch);
144     tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
145                                           cell_scratch);
146     tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
147                                           output_gate_scratch);
148   }
149 
150   // For each batch and cell: compute input_weight * input.
151   if (!use_cifg) {
152     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
153         input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
154         input_gate_scratch, /*result_stride=*/1);
155   }
156 
157   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
158       input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
159       forget_gate_scratch, /*result_stride=*/1);
160   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
161       input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
162       cell_scratch, /*result_stride=*/1);
163   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
164       input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
165       output_gate_scratch, /*result_stride=*/1);
166 
167   // If auxiliary input is available then compute aux_input_weight * aux_input
168   if (aux_input_ptr_batch != nullptr) {
169     if (!use_cifg) {
170       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
171           aux_input_to_input_weights_ptr, n_cell, n_aux_input,
172           aux_input_ptr_batch, n_batch, input_gate_scratch,
173           /*result_stride=*/1);
174     }
175 
176     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
177         aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
178         aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1);
179     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
180         aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch,
181         n_batch, cell_scratch, /*result_stride=*/1);
182     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
183         aux_input_to_output_weights_ptr, n_cell, n_aux_input,
184         aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1);
185   }
186 
187   // For each batch and cell: compute recurrent_weight * output_state.
188   if (!use_cifg) {
189     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
190         recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
191         n_batch, input_gate_scratch, /*result_stride=*/1);
192   }
193   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
194       recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
195       n_batch, forget_gate_scratch,
196       /*result_stride=*/1);
197   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
198       recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
199       n_batch, cell_scratch, /*result_stride=*/1);
200   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
201       recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
202       n_batch, output_gate_scratch,
203       /*result_stride=*/1);
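  // At this point each *_gate_scratch holds the gate pre-activation so far:
  // the bias (or zero for layer norm) plus the input, aux-input and recurrent
  // matmul contributions; peephole and layer norm are applied below.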
204 
205   // For each batch and cell: update input gate.
206   if (!use_cifg) {
207     if (use_peephole) {
208       tensor_utils::VectorBatchVectorCwiseProductAccumulate(
209           cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
210           input_gate_scratch);
211     }
212     if (is_layer_norm_lstm) {
213       tensor_utils::MeanStddevNormalization(input_gate_scratch,
214                                             input_gate_scratch, n_cell, n_batch,
215                                             kLayerNormEpsilon);
216       tensor_utils::VectorBatchVectorCwiseProduct(
217           input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
218           n_batch, input_gate_scratch);
219       tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
220                                          input_gate_scratch);
221     }
222     tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
223                                        input_gate_scratch);
224   }
225 
226   // For each batch and cell: update forget gate.
227   if (use_peephole) {
228     tensor_utils::VectorBatchVectorCwiseProductAccumulate(
229         cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
230         forget_gate_scratch);
231   }
232   if (is_layer_norm_lstm) {
233     tensor_utils::MeanStddevNormalization(forget_gate_scratch,
234                                           forget_gate_scratch, n_cell, n_batch,
235                                           kLayerNormEpsilon);
236     tensor_utils::VectorBatchVectorCwiseProduct(
237         forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
238         n_batch, forget_gate_scratch);
239     tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
240                                        forget_gate_scratch);
241   }
242   tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
243                                      forget_gate_scratch);
244 
245   // For each batch and cell: update the cell.
246   tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
247                                          n_batch * n_cell, cell_state_ptr);
248   if (is_layer_norm_lstm) {
249     tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
250                                           n_batch, kLayerNormEpsilon);
251     tensor_utils::VectorBatchVectorCwiseProduct(
252         cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
253         cell_scratch);
254     tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
255                                        cell_scratch);
256   }
257   tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
258                                         params->activation, cell_scratch);
259   if (use_cifg) {
260     tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
261                              forget_gate_scratch);
262     tensor_utils::VectorVectorCwiseProductAccumulate(
263         cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
264   } else {
265     tensor_utils::VectorVectorCwiseProductAccumulate(
266         cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
267   }
268   if (params->cell_clip > 0.0) {
269     tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
270                              params->cell_clip, cell_state_ptr);
271   }
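  // cell_state_ptr now holds the updated cell state
  // c = f .* c_prev + i .* g(pre-activation), with i replaced by (1 - f) under
  // CIFG, optionally clipped to [-cell_clip, cell_clip].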
272 
273   // For each batch and cell: update the output gate.
274   if (use_peephole) {
275     tensor_utils::VectorBatchVectorCwiseProductAccumulate(
276         cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
277         output_gate_scratch);
278   }
279   if (is_layer_norm_lstm) {
280     tensor_utils::MeanStddevNormalization(output_gate_scratch,
281                                           output_gate_scratch, n_cell, n_batch,
282                                           kLayerNormEpsilon);
283     tensor_utils::VectorBatchVectorCwiseProduct(
284         output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
285         n_batch, output_gate_scratch);
286     tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
287                                        output_gate_scratch);
288   }
289   tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
290                                      output_gate_scratch);
291   tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
292                                         params->activation, cell_scratch);
293   tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
294                                          n_batch * n_cell, output_gate_scratch);
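  // output_gate_scratch now holds the pre-projection hidden state o .* g(c).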
295 
296   const bool use_projection_weight = (projection_weights_ptr != nullptr);
297   const bool use_projection_bias = (projection_bias_ptr != nullptr);
298 
299   // For each batch: update the projection and output_state. Note that since
300   // the output batch rows may not be contiguous (output_batch_leading_dim !=
301   // n_output), we unroll the batched operations where this is the case.
302   if (output_batch_leading_dim == n_output) {
303     if (use_projection_weight) {
304       if (use_projection_bias) {
305         tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
306                                               n_batch, output_ptr_batch);
307       } else {
308         tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
309       }
310       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
311           projection_weights_ptr, n_output, n_cell, output_gate_scratch,
312           n_batch, output_ptr_batch, /*result_stride=*/1);
313       if (params->proj_clip > 0.0) {
314         tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
315                                  params->proj_clip, output_ptr_batch);
316       }
317     } else {
318       tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
319                                output_ptr_batch);
320     }
321     tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
322                              output_state_ptr);
323   } else {
324     if (use_projection_weight) {
325       if (use_projection_bias) {
326         for (int k = 0; k < n_batch; k++) {
327           tensor_utils::CopyVector(
328               projection_bias_ptr, n_output,
329               output_ptr_batch + k * output_batch_leading_dim);
330         }
331       } else {
332         for (int k = 0; k < n_batch; k++) {
333           tensor_utils::ZeroVector(
334               output_ptr_batch + k * output_batch_leading_dim, n_output);
335         }
336       }
337       for (int k = 0; k < n_batch; k++) {
338         tensor_utils::MatrixBatchVectorMultiplyAccumulate(
339             projection_weights_ptr, n_output, n_cell,
340             output_gate_scratch + k * n_cell,
341             /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim,
342             /*result_stride=*/1);
343         if (params->proj_clip > 0.0) {
344           tensor_utils::ClipVector(
345               output_ptr_batch + k * output_batch_leading_dim, n_output,
346               params->proj_clip,
347               output_ptr_batch + k * output_batch_leading_dim);
348         }
349       }
350     } else {
351       for (int k = 0; k < n_batch; k++) {
352         tensor_utils::CopyVector(
353             output_gate_scratch + k * n_output, n_output,
354             output_ptr_batch + k * output_batch_leading_dim);
355       }
356     }
357     for (int k = 0; k < n_batch; k++) {
358       tensor_utils::CopyVector(output_ptr_batch + k * output_batch_leading_dim,
359                                n_output, output_state_ptr + k * n_output);
360     }
361   }
362 }
363 
364 // Same as above but with quantized weight matrices. In detail:
365 // Input of size 'n_batch * n_input':
366 //   input_ptr_batch
367 //
368 // LSTM weights:
369 // Quantized input weights of size 'n_cell * n_input':
370 //   input_to_input_weights            - optional (can be nullptr)
371 //   input_to_forget_weights
372 //   input_to_cell_weights
373 //   input_to_output_weights
374 // Quantized auxiliary input weights of size 'n_cell * n_aux_input':
375 //   aux_input_to_input_weights        - optional
376 //   aux_input_to_forget_weights       - optional
377 //   aux_input_to_cell_weights         - optional
378 //   aux_input_to_output_weights       - optional
379 // Quantized recurrent weights of size 'n_cell * n_output':
380 //   recurrent_to_input_weights        - optional
381 //   recurrent_to_forget_weights
382 //   recurrent_to_cell_weights
383 //   recurrent_to_output_weights
384 // Quantized peephole weights of size 'n_cell', representing diagonal matrices.
385 //   cell_to_input_weights             - optional
386 //   cell_to_forget_weights            - optional
387 //   cell_to_output_weights            - optional
388 // Quantized projection weights of size 'n_output * n_cell'
389 //   projection_weights_ptr            - optional
390 // Weight scales (scalars) for each of the weights above.
391 //   input_to_input_weights_scale      - optional
392 //   input_to_forget_weights_scale
393 //   input_to_cell_weights_scale
394 //   input_to_output_weights_scale
395 //   aux_input_to_input_weights_scale  - optional
396 //   aux_input_to_forget_weights_scale - optional
397 //   aux_input_to_cell_weights_scale   - optional
398 //   aux_input_to_output_weights_scale - optional
399 //   recurrent_to_input_weights_scale  - optional
400 //   recurrent_to_forget_weights_scale
401 //   recurrent_to_cell_weights_scale
402 //   recurrent_to_output_weights_scale
403 //   cell_to_input_weights_scale
404 //   cell_to_forget_weights_scale
405 //   cell_to_output_weights_scale
406 //   projection_weights_scale          - optional
407 // Gate biases of size 'n_cell':
408 //   input_gate_bias_ptr               - optional
409 //   forget_gate_bias_ptr
410 //   cell_bias_ptr
411 //   output_gate_bias_ptr
412 //
413 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
414 //   input_layer_norm_coefficients_ptr  - optional
415 //   forget_layer_norm_coefficients_ptr - optional
416 //   cell_layer_norm_coefficients_ptr   - optional
417 //   output_layer_norm_coefficients_ptr - optional
418 //
419 // Temporary pre-allocated storage for quantized values:
420 //   quantized_input_ptr_batch (same size as input_ptr_batch)
421 //   quantized_output_state_ptr (same size as output_state_ptr)
422 //   quantized_cell_state_ptr (same size as cell_state_ptr)
423 // Temporary pre-allocated storage for recovered values:
424 //   recovered_cell_weights (same size as cell_to_*_weights)
425 //
426 // Outputs:
427 //   output_state_ptr - size 'n_batch * n_output'
428 //   cell_state_ptr   - size 'n_batch * n_cell'
429 //   output_ptr_batch - size 'n_batch * output_batch_leading_dim'
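//
// Informally, a sketch (not a spec) of the hybrid scheme used below: each
// weight matrix is stored as int8 with a single float scale, and every batch
// row of the float activations is symmetrically quantized to int8 on the fly,
// so that
//   W * x  ~=  (W_q * x_q) * weight_scale * input_scale[b].
// The integer matmul results are rescaled by the per-batch product of scales
// and accumulated into the float gate scratch buffers; the gate math, layer
// norm and cell update then run in float, and the projection matmul (if any)
// reuses the same quantize-and-rescale trick.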
430 inline void LstmStepWithAuxInput(
431     const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
432     float input_to_input_weights_scale,
433     const int8_t* input_to_forget_weights_ptr,
434     float input_to_forget_weights_scale,
435     const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
436     const int8_t* input_to_output_weights_ptr,
437     float input_to_output_weights_scale, const float* aux_input_ptr_batch,
438     const int8_t* aux_input_to_input_weights_ptr,
439     float aux_input_to_input_weights_scale,
440     const int8_t* aux_input_to_forget_weights_ptr,
441     float aux_input_to_forget_weights_scale,
442     const int8_t* aux_input_to_cell_weights_ptr,
443     float aux_input_to_cell_weights_scale,
444     const int8_t* aux_input_to_output_weights_ptr,
445     float aux_input_to_output_weights_scale,
446     const int8_t* recurrent_to_input_weights_ptr,
447     float recurrent_to_input_weights_scale,
448     const int8_t* recurrent_to_forget_weights_ptr,
449     float recurrent_to_forget_weights_scale,
450     const int8_t* recurrent_to_cell_weights_ptr,
451     float recurrent_to_cell_weights_scale,
452     const int8_t* recurrent_to_output_weights_ptr,
453     float recurrent_to_output_weights_scale,
454     const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
455     const int8_t* cell_to_forget_weights_ptr,
456     float cell_to_forget_weights_scale,
457     const int8_t* cell_to_output_weights_ptr,
458     float cell_to_output_weights_scale,
459     const float* input_layer_norm_coefficients_ptr,
460     const float* forget_layer_norm_coefficients_ptr,
461     const float* cell_layer_norm_coefficients_ptr,
462     const float* output_layer_norm_coefficients_ptr,
463     const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
464     const float* cell_bias_ptr, const float* output_gate_bias_ptr,
465     const int8_t* projection_weights_ptr, float projection_weights_scale,
466     const float* projection_bias_ptr, const TfLiteLSTMParams* params,
467     int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
468     int output_batch_leading_dim, float* input_gate_scratch,
469     float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
470     float* scaling_factors, float* product_scaling_factors,
471     float* recovered_cell_weights, int8_t* quantized_input_ptr_batch,
472     int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
473     int8_t* quantized_cell_state_ptr, float* output_state_ptr,
474     float* cell_state_ptr, float* output_ptr_batch) {
475   // Since we have already checked that weights are all there or none, we
476   // can check the existence of just one of them.
477   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
478   const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
479   const bool is_layer_norm_lstm =
480       (forget_layer_norm_coefficients_ptr != nullptr);
481 
482   // Initialize scratch buffers with bias, or with zeros for layer norm LSTM.
483   if (is_layer_norm_lstm) {
484     if (!use_cifg) {
485       tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
486     }
487     tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
488     tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
489     tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
490   } else {
491     if (!use_cifg) {
492       tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
493                                             n_batch, input_gate_scratch);
494     }
495     tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
496                                           forget_gate_scratch);
497     tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
498                                           cell_scratch);
499     tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
500                                           output_gate_scratch);
501   }
502 
503   if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
504     // The check above skips quantization and matmuls for an all-zero input.
505     float unused_min, unused_max;
506     for (int b = 0; b < n_batch; ++b) {
507       const int offset = b * n_input;
508       tensor_utils::SymmetricQuantizeFloats(
509           input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
510           &unused_min, &unused_max, &scaling_factors[b]);
511     }
512     // For each batch and cell: compute input_weight * input.
513     if (!use_cifg) {
514       for (int b = 0; b < n_batch; ++b) {
515         product_scaling_factors[b] =
516             scaling_factors[b] * input_to_input_weights_scale;
517       }
518       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
519           input_to_input_weights_ptr, n_cell, n_input,
520           quantized_input_ptr_batch, product_scaling_factors, n_batch,
521           input_gate_scratch, /*result_stride=*/1);
522     }
523 
524     for (int b = 0; b < n_batch; ++b) {
525       product_scaling_factors[b] =
526           scaling_factors[b] * input_to_forget_weights_scale;
527     }
528     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
529         input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
530         product_scaling_factors, n_batch, forget_gate_scratch,
531         /*result_stride=*/1);
532 
533     for (int b = 0; b < n_batch; ++b) {
534       product_scaling_factors[b] =
535           scaling_factors[b] * input_to_cell_weights_scale;
536     }
537     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
538         input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
539         product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);
540 
541     for (int b = 0; b < n_batch; ++b) {
542       product_scaling_factors[b] =
543           scaling_factors[b] * input_to_output_weights_scale;
544     }
545     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
546         input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
547         product_scaling_factors, n_batch, output_gate_scratch,
548         /*result_stride=*/1);
549   }
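  // Any input contributions are now accumulated in the float gate scratch
  // buffers, already rescaled back to float by scaling_factors[b] times the
  // corresponding weight scale.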
550 
551   if (aux_input_ptr_batch != nullptr &&
552       !tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_aux_input)) {
553     // The check above skips quantization and matmuls for zero aux input.
554     float unused_min, unused_max;
555     for (int b = 0; b < n_batch; ++b) {
556       const int offset = b * n_aux_input;
557       tensor_utils::SymmetricQuantizeFloats(
558           aux_input_ptr_batch + offset, n_aux_input,
559           quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max,
560           &scaling_factors[b]);
561     }
562     // For each batch and cell: compute input_weight * input.
563     if (!use_cifg) {
564       for (int b = 0; b < n_batch; ++b) {
565         product_scaling_factors[b] =
566             scaling_factors[b] * aux_input_to_input_weights_scale;
567       }
568       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
569           aux_input_to_input_weights_ptr, n_cell, n_aux_input,
570           quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
571           input_gate_scratch, /*result_stride=*/1);
572     }
573 
574     for (int b = 0; b < n_batch; ++b) {
575       product_scaling_factors[b] =
576           scaling_factors[b] * aux_input_to_forget_weights_scale;
577     }
578     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
579         aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
580         quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
581         forget_gate_scratch, /*result_stride=*/1);
582 
583     for (int b = 0; b < n_batch; ++b) {
584       product_scaling_factors[b] =
585           scaling_factors[b] * aux_input_to_cell_weights_scale;
586     }
587     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
588         aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
589         quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
590         cell_scratch, /*result_stride=*/1);
591 
592     for (int b = 0; b < n_batch; ++b) {
593       product_scaling_factors[b] =
594           scaling_factors[b] * aux_input_to_output_weights_scale;
595     }
596     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
597         aux_input_to_output_weights_ptr, n_cell, n_aux_input,
598         quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
599         output_gate_scratch, /*result_stride=*/1);
600   }
601 
602   if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
603     // The check above skips quantization and matmuls for zero output state.
604     float unused_min, unused_max;
605     for (int b = 0; b < n_batch; ++b) {
606       const int offset = b * n_output;
607       tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
608                                             quantized_output_state_ptr + offset,
609                                             &unused_min, &unused_max,
610                                             &scaling_factors[b]);
611     }
612     // For each batch and cell: compute recurrent_weight * output_state.
613     if (!use_cifg) {
614       for (int b = 0; b < n_batch; ++b) {
615         product_scaling_factors[b] =
616             scaling_factors[b] * recurrent_to_input_weights_scale;
617       }
618       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
619           recurrent_to_input_weights_ptr, n_cell, n_output,
620           quantized_output_state_ptr, product_scaling_factors, n_batch,
621           input_gate_scratch, /*result_stride=*/1);
622     }
623 
624     for (int b = 0; b < n_batch; ++b) {
625       product_scaling_factors[b] =
626           scaling_factors[b] * recurrent_to_forget_weights_scale;
627     }
628     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
629         recurrent_to_forget_weights_ptr, n_cell, n_output,
630         quantized_output_state_ptr, product_scaling_factors, n_batch,
631         forget_gate_scratch, /*result_stride=*/1);
632 
633     for (int b = 0; b < n_batch; ++b) {
634       product_scaling_factors[b] =
635           scaling_factors[b] * recurrent_to_cell_weights_scale;
636     }
637     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
638         recurrent_to_cell_weights_ptr, n_cell, n_output,
639         quantized_output_state_ptr, product_scaling_factors, n_batch,
640         cell_scratch, /*result_stride=*/1);
641 
642     for (int b = 0; b < n_batch; ++b) {
643       product_scaling_factors[b] =
644           scaling_factors[b] * recurrent_to_output_weights_scale;
645     }
646     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
647         recurrent_to_output_weights_ptr, n_cell, n_output,
648         quantized_output_state_ptr, product_scaling_factors, n_batch,
649         output_gate_scratch, /*result_stride=*/1);
650   }
651 
652   // Used to skip the peephole contributions when the cell state is all zeros.
653   bool is_cell_state_all_zeros =
654       tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
655 
656   // For each batch and cell: update input gate.
657   if (!use_cifg) {
658     if (use_peephole && !is_cell_state_all_zeros) {
659       tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
660                                          cell_to_input_weights_scale,
661                                          recovered_cell_weights);
662       tensor_utils::VectorBatchVectorCwiseProductAccumulate(
663           recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
664           input_gate_scratch);
665     }
666     if (is_layer_norm_lstm) {
667       tensor_utils::MeanStddevNormalization(input_gate_scratch,
668                                             input_gate_scratch, n_cell, n_batch,
669                                             kLayerNormEpsilon);
670       tensor_utils::VectorBatchVectorCwiseProduct(
671           input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
672           n_batch, input_gate_scratch);
673       tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
674                                          input_gate_scratch);
675     }
676     tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
677                                        input_gate_scratch);
678   }
679 
680   // For each batch and cell: update forget gate.
681   if (use_peephole && !is_cell_state_all_zeros) {
682     tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
683                                        cell_to_forget_weights_scale,
684                                        recovered_cell_weights);
685     tensor_utils::VectorBatchVectorCwiseProductAccumulate(
686         recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
687         forget_gate_scratch);
688   }
689   if (is_layer_norm_lstm) {
690     tensor_utils::MeanStddevNormalization(forget_gate_scratch,
691                                           forget_gate_scratch, n_cell, n_batch,
692                                           kLayerNormEpsilon);
693     tensor_utils::VectorBatchVectorCwiseProduct(
694         forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
695         n_batch, forget_gate_scratch);
696     tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
697                                        forget_gate_scratch);
698   }
699   tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
700                                      forget_gate_scratch);
701 
702   // For each batch and cell: update the cell.
703   tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
704                                          n_batch * n_cell, cell_state_ptr);
705   if (is_layer_norm_lstm) {
706     tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
707                                           n_batch, kLayerNormEpsilon);
708     tensor_utils::VectorBatchVectorCwiseProduct(
709         cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
710         cell_scratch);
711     tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
712                                        cell_scratch);
713   }
714   tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
715                                         params->activation, cell_scratch);
716   if (use_cifg) {
717     tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
718                              forget_gate_scratch);
719     tensor_utils::VectorVectorCwiseProductAccumulate(
720         cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
721   } else {
722     tensor_utils::VectorVectorCwiseProductAccumulate(
723         cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
724   }
725   if (params->cell_clip > 0.0) {
726     tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
727                              params->cell_clip, cell_state_ptr);
728   }
729 
730   is_cell_state_all_zeros =
731       tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
732   // For each batch and cell: update the output gate.
733   if (use_peephole && !is_cell_state_all_zeros) {
734     tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
735                                        cell_to_output_weights_scale,
736                                        recovered_cell_weights);
737     tensor_utils::VectorBatchVectorCwiseProductAccumulate(
738         recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
739         output_gate_scratch);
740   }
741   if (is_layer_norm_lstm) {
742     tensor_utils::MeanStddevNormalization(output_gate_scratch,
743                                           output_gate_scratch, n_cell, n_batch,
744                                           kLayerNormEpsilon);
745     tensor_utils::VectorBatchVectorCwiseProduct(
746         output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
747         n_batch, output_gate_scratch);
748     tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
749                                        output_gate_scratch);
750   }
751   tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
752                                      output_gate_scratch);
753   tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
754                                         params->activation, cell_scratch);
755   tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
756                                          n_batch * n_cell, output_gate_scratch);
757 
758   const bool use_projection_weight = (projection_weights_ptr != nullptr);
759   const bool use_projection_bias = (projection_bias_ptr != nullptr);
760 
761   // For each batch: update the projection and output_state. Note that since
762   // the output batch rows may not be contiguous (output_batch_leading_dim !=
763   // n_output), we unroll the batched operations where this is the case.
764   if (output_batch_leading_dim == n_output) {
765     if (use_projection_weight) {
766       if (use_projection_bias) {
767         tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
768                                               n_batch, output_ptr_batch);
769       } else {
770         tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
771       }
772       if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
773         // The check above skips the projection matmul for all-zero gate output.
774         float unused_min, unused_max;
775         for (int b = 0; b < n_batch; ++b) {
776           const int offset = b * n_cell;
777           tensor_utils::SymmetricQuantizeFloats(
778               output_gate_scratch + offset, n_cell,
779               quantized_cell_state_ptr + offset, &unused_min, &unused_max,
780               &scaling_factors[b]);
781         }
782         for (int b = 0; b < n_batch; ++b) {
783           product_scaling_factors[b] =
784               scaling_factors[b] * projection_weights_scale;
785         }
786         tensor_utils::MatrixBatchVectorMultiplyAccumulate(
787             projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
788             product_scaling_factors, n_batch, output_ptr_batch,
789             /*result_stride=*/1);
790       }
791       if (params->proj_clip > 0.0) {
792         tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
793                                  params->proj_clip, output_ptr_batch);
794       }
795     } else {
796       tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
797                                output_ptr_batch);
798     }
799     tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
800                              output_state_ptr);
801   } else {
802     if (use_projection_weight) {
803       if (use_projection_bias) {
804         for (int k = 0; k < n_batch; k++) {
805           tensor_utils::CopyVector(
806               projection_bias_ptr, n_output,
807               output_ptr_batch + k * output_batch_leading_dim);
808         }
809       } else {
810         for (int k = 0; k < n_batch; k++) {
811           tensor_utils::ZeroVector(
812               output_ptr_batch + k * output_batch_leading_dim, n_output);
813         }
814       }
815       if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
816         // The check above skips the projection matmul for all-zero gate output.
817         float unused_min, unused_max;
818         for (int b = 0; b < n_batch; ++b) {
819           const int offset = b * n_cell;
820           tensor_utils::SymmetricQuantizeFloats(
821               output_gate_scratch + offset, n_cell,
822               quantized_cell_state_ptr + offset, &unused_min, &unused_max,
823               &scaling_factors[b]);
824         }
825         for (int b = 0; b < n_batch; ++b) {
826           product_scaling_factors[b] =
827               scaling_factors[b] * projection_weights_scale;
828         }
829         for (int k = 0; k < n_batch; k++) {
830           tensor_utils::MatrixBatchVectorMultiplyAccumulate(
831               projection_weights_ptr, n_output, n_cell,
832               quantized_cell_state_ptr + k * n_cell,
833               &product_scaling_factors[k],
834               /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim,
835               /*result_stride=*/1);
836         }
837       }
838       if (params->proj_clip > 0.0) {
839         for (int k = 0; k < n_batch; k++) {
840           tensor_utils::ClipVector(
841               output_ptr_batch + k * output_batch_leading_dim, n_output,
842               params->proj_clip,
843               output_ptr_batch + k * output_batch_leading_dim);
844         }
845       }
846     } else {
847       for (int k = 0; k < n_batch; k++) {
848         tensor_utils::CopyVector(
849             output_gate_scratch + k * n_output, n_output,
850             output_ptr_batch + k * output_batch_leading_dim);
851       }
852     }
853     for (int k = 0; k < n_batch; k++) {
854       tensor_utils::CopyVector(output_ptr_batch + k * output_batch_leading_dim,
855                                n_output, output_state_ptr + k * n_output);
856     }
857   }
858 }
859 
860 int8_t* GetInt8DataPtr(const TfLiteTensor* tensor, const bool is_uint8) {
861   if (is_uint8) {
862     return reinterpret_cast<int8_t*>(tensor->data.uint8);
863   } else {
864     return tensor->data.int8;
865   }
866 }
867 
868 }  // namespace
869 
870 TfLiteStatus EvalFloat(
871     const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
872     const TfLiteTensor* input_to_forget_weights,
873     const TfLiteTensor* input_to_cell_weights,
874     const TfLiteTensor* input_to_output_weights,
875     const TfLiteTensor* recurrent_to_input_weights,
876     const TfLiteTensor* recurrent_to_forget_weights,
877     const TfLiteTensor* recurrent_to_cell_weights,
878     const TfLiteTensor* recurrent_to_output_weights,
879     const TfLiteTensor* cell_to_input_weights,
880     const TfLiteTensor* cell_to_forget_weights,
881     const TfLiteTensor* cell_to_output_weights,
882     const TfLiteTensor* input_layer_norm_coefficients,
883     const TfLiteTensor* forget_layer_norm_coefficients,
884     const TfLiteTensor* cell_layer_norm_coefficients,
885     const TfLiteTensor* output_layer_norm_coefficients,
886     const TfLiteTensor* aux_input,
887     const TfLiteTensor* aux_input_to_input_weights,
888     const TfLiteTensor* aux_input_to_forget_weights,
889     const TfLiteTensor* aux_input_to_cell_weights,
890     const TfLiteTensor* aux_input_to_output_weights,
891     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
892     const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
893     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
894     const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
895     int output_offset, TfLiteTensor* scratch_buffer,
896     TfLiteTensor* activation_state, TfLiteTensor* cell_state,
897     TfLiteTensor* output) {
898   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
899   int max_time, n_batch;
900   if (input->dims->size == 3) {
901     max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
902     n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
903   } else {
904     max_time = 1;
905     n_batch = input->dims->data[0];
906   }
907   const int n_input = input->dims->data[input->dims->size - 1];
908   const int aux_input_size =
909       (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
910 
911   // n_cell and n_output will be the same size when there is no projection.
912   const int n_cell = input_to_output_weights->dims->data[0];
913   const int n_output = recurrent_to_output_weights->dims->data[1];
914 
915   // Since we have already checked that weights are all there or none, we can
916   // check the existence of just one of them.
917   const bool use_cifg = (input_to_input_weights == nullptr);
918   const bool use_peephole = (cell_to_output_weights != nullptr);
919   const bool is_layer_norm_lstm = (forget_layer_norm_coefficients != nullptr);
920 
921   // Index the scratch buffers pointers to the global scratch buffer.
922   float* input_gate_scratch = nullptr;
923   float* cell_scratch = nullptr;
924   float* forget_gate_scratch = nullptr;
925   float* output_gate_scratch = nullptr;
926   if (use_cifg) {
927     cell_scratch = scratch_buffer->data.f;
928     forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
929     output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
930   } else {
931     input_gate_scratch = scratch_buffer->data.f;
932     cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
933     forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
934     output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
935   }
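  // Layout of scratch_buffer->data.f (each chunk is n_cell * n_batch floats):
  //   with CIFG:    [ cell | forget gate | output gate ]
  //   without CIFG: [ input gate | cell | forget gate | output gate ]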
936 
937   // Check optional tensors, the respective pointers can be null.
938   const float* input_to_input_weights_ptr =
939       (use_cifg) ? nullptr : input_to_input_weights->data.f;
940   const float* recurrent_to_input_weights_ptr =
941       (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
942   const float* input_gate_bias_ptr =
943       (use_cifg) ? nullptr : input_gate_bias->data.f;
944   const float* cell_to_input_weights_ptr =
945       (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
946   const float* cell_to_forget_weights_ptr =
947       (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
948   const float* cell_to_output_weights_ptr =
949       (use_peephole) ? cell_to_output_weights->data.f : nullptr;
950   const float* input_layer_norm_coefficients_ptr =
951       (is_layer_norm_lstm && !use_cifg) ? input_layer_norm_coefficients->data.f
952                                         : nullptr;
953   const float* forget_layer_norm_coefficients_ptr =
954       is_layer_norm_lstm ? forget_layer_norm_coefficients->data.f : nullptr;
955   const float* cell_layer_norm_coefficients_ptr =
956       is_layer_norm_lstm ? cell_layer_norm_coefficients->data.f : nullptr;
957   const float* output_layer_norm_coefficients_ptr =
958       is_layer_norm_lstm ? output_layer_norm_coefficients->data.f : nullptr;
959   const float* projection_weights_ptr =
960       (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
961   const float* projection_bias_ptr =
962       (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
963 
964   float* aux_input_ptr = nullptr;
965   float* aux_input_to_input_weights_ptr = nullptr;
966   float* aux_input_to_forget_weights_ptr = nullptr;
967   float* aux_input_to_cell_weights_ptr = nullptr;
968   float* aux_input_to_output_weights_ptr = nullptr;
969   if (aux_input_size > 0) {
970     if (!use_cifg) {
971       aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f;
972     }
973     aux_input_to_forget_weights_ptr = aux_input_to_forget_weights->data.f;
974     aux_input_to_cell_weights_ptr = aux_input_to_cell_weights->data.f;
975     aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
976   }
977 
978   const int output_batch_leading_dim =
979       output->dims->data[output->dims->size - 1];
980   if (time_major) {
981     // Loop through the sequence.
982     const int input_step = n_batch * n_input;
983     const int output_step = n_batch * output_batch_leading_dim;
984     for (int t = 0; t < max_time; t++) {
985       // If this is the forward_sequence, step forward, otherwise step
986       // backwards.
987       const int t_rel = forward_sequence ? t : max_time - t - 1;
988       const float* input_ptr_batch = input->data.f + t_rel * input_step;
989       if (aux_input) {
990         aux_input_ptr = aux_input->data.f + t_rel * input_step;
991       }
992       float* output_ptr_time =
993           output->data.f + t_rel * output_step + output_offset;
994 
995       LstmStepWithAuxInput(
996           input_ptr_batch, input_to_input_weights_ptr,
997           input_to_forget_weights->data.f, input_to_cell_weights->data.f,
998           input_to_output_weights->data.f, aux_input_ptr,
999           aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
1000           aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
1001           recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
1002           recurrent_to_cell_weights->data.f,
1003           recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
1004           cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
1005           input_layer_norm_coefficients_ptr, forget_layer_norm_coefficients_ptr,
1006           cell_layer_norm_coefficients_ptr, output_layer_norm_coefficients_ptr,
1007           input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
1008           output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
1009           params, n_batch, n_cell, n_input, aux_input_size, n_output,
1010           output_batch_leading_dim, activation_state->data.f,
1011           cell_state->data.f, input_gate_scratch, forget_gate_scratch,
1012           cell_scratch, output_gate_scratch, output_ptr_time);
1013     }
1014   } else {
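    // Batch-major input: the timesteps of one batch entry are contiguous, so
    // process each batch entry as its own sequence and call the step function
    // with n_batch == 1, offsetting the state and scratch pointers per batch.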
1015     for (int b = 0; b < n_batch; b++) {
1016       const int input_step = n_input;
1017       const int output_step = output_batch_leading_dim;
1018       for (int t = 0; t < max_time; t++) {
1019         // If this is the forward_sequence, step forward, otherwise step
1020         // backwards.
1021         const int t_rel = forward_sequence ? t : max_time - t - 1;
1022         const int time_offset = b * max_time + t_rel;
1023         const float* input_ptr = input->data.f + time_offset * input_step;
1024         if (aux_input) {
1025           aux_input_ptr = aux_input->data.f + time_offset * input_step;
1026         }
1027         float* output_ptr =
1028             output->data.f + time_offset * output_step + output_offset;
1029 
1030         // Offset the {activation,cell}_state pointers to the right batch.
1031         float* activation_state_ptr =
1032             activation_state->data.f + b * output_batch_leading_dim;
1033         float* cell_state_ptr = cell_state->data.f + b * n_cell;
1034         // Offset the scratch pointers to the right batch.
1035         float* input_gate_scratch_ptr =
1036             input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
1037         float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
1038         float* cell_scratch_ptr = cell_scratch + b * n_cell;
1039         float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
1040 
1041         LstmStepWithAuxInput(
1042             input_ptr, input_to_input_weights_ptr,
1043             input_to_forget_weights->data.f, input_to_cell_weights->data.f,
1044             input_to_output_weights->data.f, aux_input_ptr,
1045             aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
1046             aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
1047             recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
1048             recurrent_to_cell_weights->data.f,
1049             recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
1050             cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
1051             input_layer_norm_coefficients_ptr,
1052             forget_layer_norm_coefficients_ptr,
1053             cell_layer_norm_coefficients_ptr,
1054             output_layer_norm_coefficients_ptr, input_gate_bias_ptr,
1055             forget_gate_bias->data.f, cell_bias->data.f,
1056             output_gate_bias->data.f, projection_weights_ptr,
1057             projection_bias_ptr, params, /*n_batch=*/1, n_cell, n_input,
1058             aux_input_size, n_output, output_batch_leading_dim,
1059             activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
1060             forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
1061             output_ptr);
1062       }
1063     }
1064   }
1065   return kTfLiteOk;
1066 }
1067 
1068 TfLiteStatus EvalHybrid(
1069     const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
1070     const TfLiteTensor* input_to_forget_weights,
1071     const TfLiteTensor* input_to_cell_weights,
1072     const TfLiteTensor* input_to_output_weights,
1073     const TfLiteTensor* recurrent_to_input_weights,
1074     const TfLiteTensor* recurrent_to_forget_weights,
1075     const TfLiteTensor* recurrent_to_cell_weights,
1076     const TfLiteTensor* recurrent_to_output_weights,
1077     const TfLiteTensor* cell_to_input_weights,
1078     const TfLiteTensor* cell_to_forget_weights,
1079     const TfLiteTensor* cell_to_output_weights,
1080     const TfLiteTensor* input_layer_norm_coefficients,
1081     const TfLiteTensor* forget_layer_norm_coefficients,
1082     const TfLiteTensor* cell_layer_norm_coefficients,
1083     const TfLiteTensor* output_layer_norm_coefficients,
1084     const TfLiteTensor* aux_input,
1085     const TfLiteTensor* aux_input_to_input_weights,
1086     const TfLiteTensor* aux_input_to_forget_weights,
1087     const TfLiteTensor* aux_input_to_cell_weights,
1088     const TfLiteTensor* aux_input_to_output_weights,
1089     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
1090     const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
1091     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
1092     const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
1093     int output_offset, TfLiteTensor* scratch_buffer,
1094     TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
1095     TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
1096     TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
1097     TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
1098     TfLiteTensor* cell_state, TfLiteTensor* output) {
1099   // For operations that use int8 instead of uint8 we need to fetch raw data
1100   // from the tensor differently. We use this bool for that condition.
1101   const bool is_uint8_hybrid = input_to_output_weights->type == kTfLiteUInt8;
1102   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
1103   const int n_input = input->dims->data[input->dims->size - 1];
1104   int max_time, n_batch;
1105   if (input->dims->size == 2) {
1106     max_time = 1;
1107     n_batch = input->dims->data[0];
1108   } else {
1109     max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
1110     n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
1111   }
1112   const int aux_input_size =
1113       (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
1114   // n_cell and n_output will be the same size when there is no projection.
1115   const int n_cell = input_to_output_weights->dims->data[0];
1116   const int n_output = recurrent_to_output_weights->dims->data[1];
1117 
1118   // Since we have already checked that the weights are either all present or
1119   // all absent, checking the existence of only one is enough to set each flag.
1120   const bool use_cifg = (input_to_input_weights == nullptr);
1121   const bool use_peephole = (cell_to_output_weights != nullptr);
1122   const bool is_layer_norm_lstm = (forget_layer_norm_coefficients != nullptr);
1123 
1124   float* input_gate_scratch = nullptr;
1125   float* cell_scratch = nullptr;
1126   float* forget_gate_scratch = nullptr;
1127   float* output_gate_scratch = nullptr;
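  // With CIFG the input gate is coupled to the forget gate (there is no
  // separate input gate computation), so the scratch buffer only needs three
  // segments of n_cell * n_batch floats instead of four.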
1128   if (use_cifg) {
1129     cell_scratch = scratch_buffer->data.f;
1130     forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
1131     output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
1132   } else {
1133     input_gate_scratch = scratch_buffer->data.f;
1134     cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
1135     forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
1136     output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
1137   }
1138 
1139   // Check optional tensors; the respective pointers can be null.
1140   int8_t* input_to_input_weights_ptr = nullptr;
1141   float input_to_input_weights_scale = 1.0f;
1142   int8_t* recurrent_to_input_weights_ptr = nullptr;
1143   float recurrent_to_input_weights_scale = 1.0f;
1144   float* input_gate_bias_ptr = nullptr;
1145   if (!use_cifg) {
1146     input_to_input_weights_ptr =
1147         GetInt8DataPtr(input_to_input_weights, is_uint8_hybrid);
1148     recurrent_to_input_weights_ptr =
1149         GetInt8DataPtr(recurrent_to_input_weights, is_uint8_hybrid);
1150     input_gate_bias_ptr = input_gate_bias->data.f;
1151     input_to_input_weights_scale = input_to_input_weights->params.scale;
1152     recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
1153   }
1154 
1155   int8_t* cell_to_input_weights_ptr = nullptr;
1156   int8_t* cell_to_forget_weights_ptr = nullptr;
1157   int8_t* cell_to_output_weights_ptr = nullptr;
1158   float cell_to_input_weights_scale = 1.0f;
1159   float cell_to_forget_weights_scale = 1.0f;
1160   float cell_to_output_weights_scale = 1.0f;
1161   if (use_peephole) {
1162     if (!use_cifg) {
1163       cell_to_input_weights_ptr =
1164           GetInt8DataPtr(cell_to_input_weights, is_uint8_hybrid);
1165       cell_to_input_weights_scale = cell_to_input_weights->params.scale;
1166     }
1167     cell_to_forget_weights_ptr =
1168         GetInt8DataPtr(cell_to_forget_weights, is_uint8_hybrid);
1169     cell_to_output_weights_ptr =
1170         GetInt8DataPtr(cell_to_output_weights, is_uint8_hybrid);
1171     cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
1172     cell_to_output_weights_scale = cell_to_output_weights->params.scale;
1173   }
1174 
1175   const float* input_layer_norm_coefficients_ptr =
1176       (is_layer_norm_lstm && !use_cifg) ? input_layer_norm_coefficients->data.f
1177                                         : nullptr;
1178   const float* forget_layer_norm_coefficients_ptr =
1179       is_layer_norm_lstm ? forget_layer_norm_coefficients->data.f : nullptr;
1180   const float* cell_layer_norm_coefficients_ptr =
1181       is_layer_norm_lstm ? cell_layer_norm_coefficients->data.f : nullptr;
1182   const float* output_layer_norm_coefficients_ptr =
1183       is_layer_norm_lstm ? output_layer_norm_coefficients->data.f : nullptr;
1184 
1185   const int8_t* projection_weights_ptr =
1186       (projection_weights == nullptr)
1187           ? nullptr
1188           : GetInt8DataPtr(projection_weights, is_uint8_hybrid);
1189   const float projection_weights_scale =
1190       (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
1191   const float* projection_bias_ptr =
1192       (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
1193 
1194   // Required tensors; their pointers are non-null.
1195   const int8_t* input_to_forget_weights_ptr =
1196       GetInt8DataPtr(input_to_forget_weights, is_uint8_hybrid);
1197   const float input_to_forget_weights_scale =
1198       input_to_forget_weights->params.scale;
1199   const int8_t* input_to_cell_weights_ptr =
1200       GetInt8DataPtr(input_to_cell_weights, is_uint8_hybrid);
1201   const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
1202   const int8_t* input_to_output_weights_ptr =
1203       GetInt8DataPtr(input_to_output_weights, is_uint8_hybrid);
1204   const float input_to_output_weights_scale =
1205       input_to_output_weights->params.scale;
1206   const int8_t* recurrent_to_forget_weights_ptr =
1207       GetInt8DataPtr(recurrent_to_forget_weights, is_uint8_hybrid);
1208   const float recurrent_to_forget_weights_scale =
1209       recurrent_to_forget_weights->params.scale;
1210   const int8_t* recurrent_to_cell_weights_ptr =
1211       GetInt8DataPtr(recurrent_to_cell_weights, is_uint8_hybrid);
1212   const float recurrent_to_cell_weights_scale =
1213       recurrent_to_cell_weights->params.scale;
1214   const int8_t* recurrent_to_output_weights_ptr =
1215       GetInt8DataPtr(recurrent_to_output_weights, is_uint8_hybrid);
1216   const float recurrent_to_output_weights_scale =
1217       recurrent_to_output_weights->params.scale;
1218   const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
1219   const float* cell_bias_ptr = cell_bias->data.f;
1220   const float* output_gate_bias_ptr = output_gate_bias->data.f;
1221 
1222   // Temporary storage for quantized values and scaling factors.
1223   int8_t* quantized_input_ptr =
1224       GetInt8DataPtr(input_quantized, is_uint8_hybrid);
1225   int8_t* quantized_aux_input_ptr =
1226       (aux_input_quantized == nullptr)
1227           ? nullptr
1228           : GetInt8DataPtr(aux_input_quantized, is_uint8_hybrid);
1229   int8_t* quantized_output_state_ptr =
1230       GetInt8DataPtr(output_state_quantized, is_uint8_hybrid);
1231   int8_t* quantized_cell_state_ptr =
1232       GetInt8DataPtr(cell_state_quantized, is_uint8_hybrid);
1233   float* scaling_factors_ptr = scaling_factors->data.f;
1234   float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
1235   float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
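  // Roughly: inside LstmStepWithAuxInput the float activations are quantized
  // per batch into the *_quantized buffers (their scales landing in
  // scaling_factors), prod_scaling_factors is scratch for combining those
  // scales with the weight scales, and recovered_cell_weights is scratch for
  // dequantized peephole weights.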
1236 
1237   // Auxiliary input and weights.
1238   float* aux_input_ptr = nullptr;
1239   int8_t* aux_input_to_input_weights_ptr = nullptr;
1240   int8_t* aux_input_to_forget_weights_ptr = nullptr;
1241   int8_t* aux_input_to_cell_weights_ptr = nullptr;
1242   int8_t* aux_input_to_output_weights_ptr = nullptr;
1243   float aux_input_to_input_weights_scale = 0.0f;
1244   float aux_input_to_forget_weights_scale = 0.0f;
1245   float aux_input_to_cell_weights_scale = 0.0f;
1246   float aux_input_to_output_weights_scale = 0.0f;
1247   if (aux_input_size > 0) {
1248     if (!use_cifg) {
1249       aux_input_to_input_weights_ptr =
1250           GetInt8DataPtr(aux_input_to_input_weights, is_uint8_hybrid);
1251     }
1252     aux_input_to_forget_weights_ptr =
1253         GetInt8DataPtr(aux_input_to_forget_weights, is_uint8_hybrid);
1254     aux_input_to_cell_weights_ptr =
1255         GetInt8DataPtr(aux_input_to_cell_weights, is_uint8_hybrid);
1256     aux_input_to_output_weights_ptr =
1257         GetInt8DataPtr(aux_input_to_output_weights, is_uint8_hybrid);
1258     if (!use_cifg) {
1259       aux_input_to_input_weights_scale =
1260           aux_input_to_input_weights->params.scale;
1261     }
1262     aux_input_to_forget_weights_scale =
1263         aux_input_to_forget_weights->params.scale;
1264     aux_input_to_cell_weights_scale = aux_input_to_cell_weights->params.scale;
1265     aux_input_to_output_weights_scale =
1266         aux_input_to_output_weights->params.scale;
1267   }
1268 
1269   const int output_batch_leading_dim =
1270       output->dims->data[output->dims->size - 1];
1271   if (time_major) {
1272     // Feed the sequence into the LSTM step-by-step.
1273     const int input_step = n_batch * n_input;
1274     const int output_step = n_batch * output_batch_leading_dim;
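    // As in EvalFloat, one call per time step covers the whole batch; the
    // hybrid step additionally quantizes the float activations on the fly
    // using the temporary buffers set up above.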
1275     for (int t = 0; t < max_time; t++) {
1276       // When processing the forward sequence, step forward in time;
1277       // otherwise step backwards.
1278       const int t_rel = forward_sequence ? t : max_time - t - 1;
1279       const float* input_ptr_batch = input->data.f + t_rel * input_step;
1280       if (aux_input) {
1281         aux_input_ptr = aux_input->data.f + t_rel * input_step;
1282       }
1283       float* output_ptr_batch =
1284           output->data.f + t_rel * output_step + output_offset;
1285 
1286       LstmStepWithAuxInput(
1287           input_ptr_batch, input_to_input_weights_ptr,
1288           input_to_input_weights_scale, input_to_forget_weights_ptr,
1289           input_to_forget_weights_scale, input_to_cell_weights_ptr,
1290           input_to_cell_weights_scale, input_to_output_weights_ptr,
1291           input_to_output_weights_scale, aux_input_ptr,
1292           aux_input_to_input_weights_ptr, aux_input_to_input_weights_scale,
1293           aux_input_to_forget_weights_ptr, aux_input_to_forget_weights_scale,
1294           aux_input_to_cell_weights_ptr, aux_input_to_cell_weights_scale,
1295           aux_input_to_output_weights_ptr, aux_input_to_output_weights_scale,
1296           recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
1297           recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
1298           recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
1299           recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
1300           cell_to_input_weights_ptr, cell_to_input_weights_scale,
1301           cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
1302           cell_to_output_weights_ptr, cell_to_output_weights_scale,
1303           input_layer_norm_coefficients_ptr, forget_layer_norm_coefficients_ptr,
1304           cell_layer_norm_coefficients_ptr, output_layer_norm_coefficients_ptr,
1305           input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
1306           output_gate_bias_ptr, projection_weights_ptr,
1307           projection_weights_scale, projection_bias_ptr, params, n_batch,
1308           n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
1309           input_gate_scratch, forget_gate_scratch, cell_scratch,
1310           output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
1311           recovered_cell_weights_ptr, quantized_input_ptr,
1312           quantized_aux_input_ptr, quantized_output_state_ptr,
1313           quantized_cell_state_ptr, output_state->data.f, cell_state->data.f,
1314           output_ptr_batch);
1315     }
1316   } else {
1317     for (int b = 0; b < n_batch; b++) {
1318       const int input_step = n_input;
1319       const int output_step = output_batch_leading_dim;
1320       for (int t = 0; t < max_time; t++) {
1321         // When processing the forward sequence, step forward in time;
1322         // otherwise step backwards.
1323         const int t_rel = forward_sequence ? t : max_time - t - 1;
1324         const int time_offset = b * max_time + t_rel;
1325         const float* input_ptr = input->data.f + time_offset * input_step;
1326         if (aux_input) {
1327           aux_input_ptr = aux_input->data.f + time_offset * input_step;
1328         }
1329         float* output_ptr =
1330             output->data.f + time_offset * output_step + output_offset;
1331 
1332         // Offset the {output,cell}_state pointers to the right batch.
1333         float* output_state_ptr =
1334             output_state->data.f + b * output_batch_leading_dim;
1335         float* cell_state_ptr = cell_state->data.f + b * n_cell;
1336         // Offset the scratch pointers to the right batch.
1337         float* input_gate_scratch_ptr =
1338             input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
1339         float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
1340         float* cell_scratch_ptr = cell_scratch + b * n_cell;
1341         float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
1342 
1343         LstmStepWithAuxInput(
1344             input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
1345             input_to_forget_weights_ptr, input_to_forget_weights_scale,
1346             input_to_cell_weights_ptr, input_to_cell_weights_scale,
1347             input_to_output_weights_ptr, input_to_output_weights_scale,
1348             aux_input_ptr, aux_input_to_input_weights_ptr,
1349             aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
1350             aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
1351             aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
1352             aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
1353             recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
1354             recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
1355             recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
1356             recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
1357             cell_to_input_weights_scale, cell_to_forget_weights_ptr,
1358             cell_to_forget_weights_scale, cell_to_output_weights_ptr,
1359             cell_to_output_weights_scale, input_layer_norm_coefficients_ptr,
1360             forget_layer_norm_coefficients_ptr,
1361             cell_layer_norm_coefficients_ptr,
1362             output_layer_norm_coefficients_ptr, input_gate_bias_ptr,
1363             forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
1364             projection_weights_ptr, projection_weights_scale,
1365             projection_bias_ptr, params,
1366             /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
1367             output_batch_leading_dim, input_gate_scratch_ptr,
1368             forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
1369             scaling_factors_ptr, prod_scaling_factors_ptr,
1370             recovered_cell_weights_ptr, quantized_input_ptr,
1371             quantized_aux_input_ptr, quantized_output_state_ptr,
1372             quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
1373             output_ptr);
1374       }
1375     }
1376   }
1377 
1378   return kTfLiteOk;
1379 }
1380 
1381 }  // namespace lstm_eval
1382 }  // namespace builtin
1383 }  // namespace ops
1384 }  // namespace tflite
1385