/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/lstm_eval.h"

#include <math.h>
#include <string.h>

#include <algorithm>
#include <cstdint>
#include <memory>
#include <vector>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace lstm_eval {
namespace {

void ComputeRowSums(
    int32_t* input_to_input_row_sums, int32_t* input_to_forget_row_sums,
    int32_t* input_to_cell_row_sums, int32_t* input_to_output_row_sums,
    int32_t* aux_input_to_input_row_sums, int32_t* aux_input_to_forget_row_sums,
    int32_t* aux_input_to_cell_row_sums, int32_t* aux_input_to_output_row_sums,
    int32_t* recurrent_to_input_row_sums, int32_t* recurrent_to_forget_row_sums,
    int32_t* recurrent_to_cell_row_sums, int32_t* recurrent_to_output_row_sums,
    int32_t* projection_weights_row_sums, int32_t* row_sums, int n_cell,
    int n_input, int n_aux_input, int n_output,
    const int8_t* input_to_input_weights_ptr,
    const int8_t* input_to_forget_weights_ptr,
    const int8_t* input_to_cell_weights_ptr,
    const int8_t* input_to_output_weights_ptr,
    const int8_t* aux_input_to_input_weights_ptr,
    const int8_t* aux_input_to_forget_weights_ptr,
    const int8_t* aux_input_to_cell_weights_ptr,
    const int8_t* aux_input_to_output_weights_ptr,
    const int8_t* recurrent_to_input_weights_ptr,
    const int8_t* recurrent_to_forget_weights_ptr,
    const int8_t* recurrent_to_cell_weights_ptr,
    const int8_t* recurrent_to_output_weights_ptr,
    const int8_t* projection_weights_ptr, bool use_cifg,
    const float* aux_input_ptr) {
  // Compute the row sums for dequantization
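  // (Sketch of why these sums help, assuming asymmetrically quantized inputs:
  // for an int8 input x with zero point zp and a weight row W[r, :],
  //   sum_i W[r, i] * (x[i] - zp) = sum_i W[r, i] * x[i] - zp * row_sum[r],
  // so precomputing row_sum[r] = sum_i W[r, i] once per weight matrix lets the
  // hybrid kernels apply the zero-point correction without rescanning the
  // weights on every step.)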
  if (!use_cifg) {
    tensor_utils::ReductionSumVector(input_to_input_weights_ptr,
                                     input_to_input_row_sums, n_cell, n_input);
  }
  tensor_utils::ReductionSumVector(input_to_forget_weights_ptr,
                                   input_to_forget_row_sums, n_cell, n_input);
  tensor_utils::ReductionSumVector(input_to_cell_weights_ptr,
                                   input_to_cell_row_sums, n_cell, n_input);
  tensor_utils::ReductionSumVector(input_to_output_weights_ptr,
                                   input_to_output_row_sums, n_cell, n_input);

  if (aux_input_ptr) {
    if (!use_cifg) {
      tensor_utils::ReductionSumVector(aux_input_to_input_weights_ptr,
                                       aux_input_to_input_row_sums, n_cell,
                                       n_aux_input);
    }
    tensor_utils::ReductionSumVector(aux_input_to_forget_weights_ptr,
                                     aux_input_to_forget_row_sums, n_cell,
                                     n_aux_input);
    tensor_utils::ReductionSumVector(aux_input_to_cell_weights_ptr,
                                     aux_input_to_cell_row_sums, n_cell,
                                     n_aux_input);
    tensor_utils::ReductionSumVector(aux_input_to_output_weights_ptr,
                                     aux_input_to_output_row_sums, n_cell,
                                     n_aux_input);
  }
  if (!use_cifg) {
    tensor_utils::ReductionSumVector(recurrent_to_input_weights_ptr,
                                     recurrent_to_input_row_sums, n_cell,
                                     n_output);
  }
  tensor_utils::ReductionSumVector(recurrent_to_forget_weights_ptr,
                                   recurrent_to_forget_row_sums, n_cell,
                                   n_output);
  tensor_utils::ReductionSumVector(recurrent_to_cell_weights_ptr,
                                   recurrent_to_cell_row_sums, n_cell,
                                   n_output);
  tensor_utils::ReductionSumVector(recurrent_to_output_weights_ptr,
                                   recurrent_to_output_row_sums, n_cell,
                                   n_output);

  if (projection_weights_ptr != nullptr) {
    tensor_utils::ReductionSumVector(
        projection_weights_ptr, projection_weights_row_sums, n_output, n_cell);
  }
}

inline float GetTensorScale(const TfLiteTensor* tensor) {
  return tensor == nullptr ? 1.0f : tensor->params.scale;
}

// LINT.IfChange
// Calculates a single LSTM gate.
//
// Implements the following formula: (* is matrix multiply)
//   gate = activate(W_input    * input + W_aux       * aux_input   +
//                   W_peephole * cell  + W_recurrent * prev_output + bias)
// with layer norm:
//   gate = activate(W_norm * normalize(...) + bias)
//   (the bias is not added inside normalize)
//
// Activation is sigmoid except for the "cell" gate (configurable, usually tanh).
//
// Parameters:
// Input vectors (to LSTM):    | Size:                | Optional?
//   input                     | n_input              |
//   aux_input                 | n_aux_input          | y (bidir LSTM)
// Input vectors (persistent states):
//   output_state              | n_output             |
//   cell_state                | n_cell               |
// 'Constant' inputs:
//   input_to_gate_weights     | n_cell * n_input     |
//   aux_input_to_gate_weights | n_cell * n_aux_input | y (bidir LSTM)
//   recurrent_to_gate_weights | n_cell * n_output    |
//   cell_to_gate_weights      | n_cell               | y (peephole)
//   gate_bias                 | n_cell               |
//   layer_norm_coefficients   | n_cell               | y (layer norm)
// Output vector:
//   gate                      | n_cell               |
// Scalar parameters:
//   n_batch                                    - batch size / number of vectors
//   n_input, n_aux_input, n_output, n_cell     - size of vectors.
//   activation                                 - activation to use.
//   is_input_all_zeros, is_aux_input_all_zeros - if input vectors are all zero.
//   use_layer_norm                             - if doing layer norm LSTM.
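//
// Illustrative note (hypothetical sizes, hedged): with n_batch = 2 and
// n_cell = 4 the 'gate' buffer holds 2 * 4 floats, and the code below treats it
// as batch-major, i.e. gate[b * n_cell + c] appears to be gate value c for
// batch b. The calls in LstmStepFloat further down show the typical argument
// wiring, e.g. the forget gate passes input_to_forget_weights_ptr and
// kTfLiteActSigmoid.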
inline void CalculateLstmGateFloat(
    const float* input, const float* input_to_gate_weights,
    const float* aux_input, const float* aux_input_to_gate_weights,
    const float* output_state, const float* recurrent_to_gate_weights,
    const float* cell_state, const float* cell_to_gate_weights,
    const float* layer_norm_coefficients, const float* gate_bias,
    const int n_batch, const int n_input, const int n_aux_input,
    const int n_output, const int n_cell,
    const TfLiteFusedActivation activation, float* gate,
    const bool is_input_all_zeros, const bool is_aux_input_all_zeros) {
  const bool use_peephole = (cell_to_gate_weights != nullptr);
  const bool use_layer_norm = (layer_norm_coefficients != nullptr);

  // Initialize scratch buffers with bias for regular lstm or initialize with
  // zero for layer norm lstm.
  if (use_layer_norm) {
    std::fill_n(gate, n_cell * n_batch, 0.0f);
  } else {
    tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
  }
  // For each batch and cell: compute input_weight * input.
  // Skip if input is all zeros.
  if (!is_input_all_zeros) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_gate_weights, n_cell, n_input, input, n_batch, gate);
  }
  // For each batch and cell: compute aux_input_weight * aux_input.
  // Skip if auxiliary input is not available or all zeros.
  if (!is_aux_input_all_zeros) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights,
                                                      n_cell, n_aux_input,
                                                      aux_input, n_batch, gate);
  }
  // For each batch and cell: compute recurrent_weight * output_state.
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      recurrent_to_gate_weights, n_cell, n_output, output_state, n_batch, gate);
  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
  if (use_peephole) {
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_gate_weights, n_cell, cell_state, n_batch, gate);
  }
  // Do layer normalization (if layer norm LSTM)
  if (use_layer_norm) {
    tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell,
                                                gate, n_batch, gate);
    tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
  }
  // Apply activation
  tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell, activation,
                                        gate);
}

// Updates the LSTM cell state, used by both float and hybrid LSTM versions.
//
// Implements the following formula:
//   cell_state_new = clip(forget_gate * cell_state + input_gate * cell_gate)
//
// With CIFG LSTM, input gate is replaced by (1-forget_gate).
//
// Parameters:
//  - n_batch, n_cell: sizes of vectors
//  - cell_state: input/output vector, size n_batch*n_cell
//  - input_gate: input vector, size n_batch*n_cell.
//  - forget_gate: input/scratch vector, size n_batch*n_cell, modified with CIFG
//  - cell_gate: input vector, size n_batch*n_cell.
//  - use_cifg: use 1-forget_gate instead of input_gate.
//  - clip: if > 0, clip the resulting cell state to [-clip, +clip].
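// (Worked example with illustrative numbers, not from a real model: with CIFG,
// forget_gate = 0.25, cell_state = 4.0 and cell_gate = 1.0 give
// cell_state_new = 0.25 * 4.0 + (1 - 0.25) * 1.0 = 1.75, before clipping.)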
void UpdateLstmCellFloat(int n_batch, int n_cell, float* cell_state,
                         const float* input_gate, float* forget_gate,
                         const float* cell_gate, bool use_cifg, float clip) {
  tensor_utils::VectorVectorCwiseProduct(forget_gate, cell_state,
                                         n_batch * n_cell, cell_state);

  if (use_cifg) {
    // With CIFG, input_gate = 1-forget_gate. Use the forget_gate array as
    // scratch, as input_gate array is not allocated in this case. (Be careful
    // not to write to the scratch before reading the forget gate data.)
    float* scratch = forget_gate;
    tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_gate, scratch, n_batch * n_cell, cell_state);
  } else {
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_gate, input_gate, n_batch * n_cell, cell_state);
  }
  if (clip > 0.0f) {
    tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
  }
}

// Calculates the output state tensor of an LSTM step.
//
// Implements the following formula:
//   output_no_projection = output_gate .* activate(cell_state)
//     (elementwise vector product)
// If no projection is used:
//   output = output_state = output_no_projection
// With projection:
//   output = output_state = clip(W*output_no_projection + bias)
//
// The final output may use a different row 'stride' than n_output, so the
// caller copies output_state into the output tensor row by row (see
// LstmStepFloat); output_state itself must be contiguous.
//
// Parameters:
//  - n_batch: batch size; the number of distinct vectors in each array.
//  - n_cell, n_output: sizes of vectors.
//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
//  - projection_weights, projection_weights_scale, projection_bias:
//      constant inputs, describing projection matrix and bias.
//  - proj_clip: if > 0, clip the output of the projection.
//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
//  - scratch: scratch area, size n_batch*n_cell.
void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output,
                              const float* cell_state, const float* output_gate,
                              TfLiteFusedActivation activation,
                              const float* projection_weights,
                              const float* projection_bias,
                              const float proj_clip, float* output_state,
                              float* scratch) {
  tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
                                        activation, scratch);
  tensor_utils::VectorVectorCwiseProduct(output_gate, scratch, n_batch * n_cell,
                                         scratch);

  const bool use_projection = (projection_weights != nullptr);
  const bool use_projection_bias = (projection_bias != nullptr);

  if (use_projection) {
    if (use_projection_bias) {
      tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch,
                                            output_state);
    } else {
      std::fill_n(output_state, n_batch * n_output, 0.0f);
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        projection_weights, n_output, n_cell, scratch, n_batch, output_state);
    if (proj_clip > 0.0f) {
      tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
    }
  } else {
    std::copy_n(scratch, n_batch * n_output, output_state);
  }
}
// LINT.ThenChange(../tools/optimize/calibration/builtin_logging_ops/lstm.cc,\
//                 ../experimental/kernels/fp16/lstm_eval.cc)

// Calculates a single LSTM gate, hybrid version.
// Implements the same functionality as CalculateLstmGateFloat.
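// (Background, hedged: in the hybrid path the weights are stored as int8 with a
// per-tensor float scale, while float activations are quantized on the fly.
// input_sf / aux_input_sf / output_state_sf hold per-batch scaling factors,
// input_zp / aux_input_zp / output_state_zp the matching zero points, and the
// *_row_sums arrays cache weight row sums for the zero-point correction
// described above ComputeRowSums. The optional *_ledger arrays accompany sparse
// weight storage.)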
void CalculateLstmGateHybrid(
    // Input and weights
    const int8_t* input, const float* input_sf, const int32_t* input_zp,
    const int8_t* input_to_gate_weights,
    const uint8_t* input_to_gate_weights_ledger,
    const float input_to_gate_weights_scale, int32_t* input_to_gate_row_sums,
    // Aux input and weights
    const int8_t* aux_input, const float* aux_input_sf,
    const int32_t* aux_input_zp, const int8_t* aux_input_to_gate_weights,
    const float aux_input_to_gate_weights_scale,
    int32_t* aux_input_to_gate_row_sums,
    // Output state and weights
    const int8_t* output_state, const float* output_state_sf,
    const int32_t* output_state_zp, const int8_t* recurrent_to_gate_weights,
    const uint8_t* recurrent_to_gate_weights_ledger,
    const float recurrent_to_gate_weights_scale,
    int32_t* recurrent_to_gate_row_sums,
    // Cell state and weights (peephole LSTM)
    const float* cell_state, const int8_t* cell_to_gate_weights,
    const float cell_to_gate_weights_scale,
    // Layer normalization coefficients (layer norm LSTM) + gate bias
    const float* layer_norm_coefficients, const float* gate_bias,
    // Array sizes
    const int n_batch, const int n_input, const int n_aux_input,
    const int n_output, const int n_cell,
    const TfLiteFusedActivation activation,
    // Output
    float* gate,
    // Parameters for performance optimizations
    const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
    const bool is_output_state_all_zeros, bool* compute_row_sums,
    CpuBackendContext* context,
    // Scratch arrays
    float* scratch0,        // size: n_batch
    float* scratch1,        // size: n_cell, only used if peephole LSTM
    int32_t* accum_scratch  // For MatrixBatchVectorMultiplyAccumulate
) {
  const bool use_peephole = (cell_to_gate_weights != nullptr);
  const bool use_layer_norm = (layer_norm_coefficients != nullptr);

  // Initialize scratch buffers with bias for regular lstm or initialize with
  // zero for layer norm lstm.
  if (use_layer_norm) {
    std::fill_n(gate, n_cell * n_batch, 0.0f);
  } else {
    tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
  }
  // For each batch and cell: compute input_weight * input.
  // Skip if input is all zeros.
  if (!is_input_all_zeros) {
    if (input_to_gate_weights_ledger != nullptr) {
      std::vector<float> scales(n_batch);
      for (int i = 0; i < n_batch; i++) {
        scales[i] = input_to_gate_weights_scale * input_sf[i];
      }
      tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
          input_to_gate_weights, input_to_gate_weights_ledger, n_cell, n_input,
          input, scales.data(), n_batch, gate);

    } else {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          input_to_gate_weights, n_cell, n_input, input,
          input_to_gate_weights_scale, input_sf, n_batch, gate,
          /*per_channel_scale=*/nullptr, input_zp, accum_scratch,
          input_to_gate_row_sums, compute_row_sums, scratch0, context);
    }
  }
  // For each batch and cell: compute aux_input_weight * aux_input.
  // Skip if auxiliary input is not available or all zeros.
  if (!is_aux_input_all_zeros) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_gate_weights, n_cell, n_aux_input, aux_input,
        aux_input_to_gate_weights_scale, aux_input_sf, n_batch, gate,
        /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch,
        aux_input_to_gate_row_sums, compute_row_sums, scratch0, context);
  }
  // For each batch and cell: compute recurrent_weight * output_state.
  // Skip if output state is all zeros.
  if (!is_output_state_all_zeros) {
    if (recurrent_to_gate_weights_ledger != nullptr) {
      std::vector<float> scales(n_batch);
      for (int i = 0; i < n_batch; i++) {
        scales[i] = recurrent_to_gate_weights_scale * input_sf[i];
      }
      tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
          recurrent_to_gate_weights, recurrent_to_gate_weights_ledger, n_cell,
          n_output, output_state, scales.data(), n_batch, gate);
    } else {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          recurrent_to_gate_weights, n_cell, n_output, output_state,
          recurrent_to_gate_weights_scale, output_state_sf, n_batch, gate,
          /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch,
          recurrent_to_gate_row_sums, compute_row_sums, scratch0, context);
    }
  }
  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
  if (use_peephole) {
    float* recovered_cell_weights = scratch1;
    tensor_utils::VectorScalarMultiply(cell_to_gate_weights, n_cell,
                                       cell_to_gate_weights_scale,
                                       recovered_cell_weights);
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        recovered_cell_weights, n_cell, cell_state, n_batch, gate);
  }
  // Do layer normalization (if layer norm LSTM)
  if (use_layer_norm) {
    tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell,
                                                gate, n_batch, gate);
    tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
  }
  // Apply activation
  tensor_utils::ApplyActivationToVector(gate, n_cell * n_batch, activation,
                                        gate);
}

// Calculates the output state tensor of an LSTM step. See Float version too.
//
// Parameters:
//  - n_batch: batch size; the number of distinct vectors in each array.
//  - n_cell, n_output: sizes of vectors.
//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
//  - projection_weights, projection_weights_scale, projection_bias:
//      constant inputs, describing projection matrix and bias.
//  - proj_clip: if > 0, clip the output of the projection.
//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
//  - asymmetric_quantize_inputs: parameter to control quantization.
//  - projection_weights_row_sums, compute_row_sums, context: Data for optimized
//      MatrixBatchVectorMultiplyAccumulate.
//  - scratch0: scratch area of size n_batch*n_cell
//  - scratch1: scratch area of size n_batch*n_cell
//  - scratch2: scratch area of size n_batch
//  - scratch3: scratch area of size n_batch
//  - scratch4: scratch area used by MatrixBatchVectorMultiplyAccumulate
void CalculateLstmOutputHybrid(
    int n_batch, int n_cell, int n_output, const float* cell_state,
    const float* output_gate, TfLiteFusedActivation activation,
    const int8_t* projection_weights, const uint8_t* projection_weights_ledger,
    float projection_weights_scale, const float* projection_bias,
    const float proj_clip, float* output_state, bool asymmetric_quantize_inputs,
    int32_t* projection_weights_row_sums, bool* compute_row_sums,
    CpuBackendContext* context, float* scratch0, int8_t* scratch1,
    float* scratch2, int32_t* scratch3, int32_t* scratch4) {
  tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
                                        activation, scratch0);
  tensor_utils::VectorVectorCwiseProduct(output_gate, scratch0,
                                         n_batch * n_cell, scratch0);

  const bool use_projection = (projection_weights != nullptr);
  const bool use_projection_bias = (projection_bias != nullptr);

  if (use_projection) {
    if (use_projection_bias) {
      tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch,
                                            output_state);
    } else {
      std::fill_n(output_state, n_batch * n_output, 0.0f);
    }
    if (!tensor_utils::IsZeroVector(scratch0, n_batch * n_cell)) {
      // Save quantization and matmul computation for all zero output.
      tensor_utils::BatchQuantizeFloats(scratch0, n_batch, n_cell, scratch1,
                                        scratch2, scratch3,
                                        asymmetric_quantize_inputs);
      if (projection_weights_ledger != nullptr) {
        std::vector<float> scales(n_batch);
        for (int i = 0; i < n_batch; i++) {
          scales[i] = projection_weights_scale * scratch2[i];
        }
        tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
            projection_weights, projection_weights_ledger, n_output, n_cell,
            scratch1, scales.data(), n_batch, output_state);
      } else {
        tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            projection_weights, n_output, n_cell, scratch1,
            projection_weights_scale, scratch2, n_batch, output_state,
            /*per_channel_scale=*/nullptr, scratch3, scratch4,
            projection_weights_row_sums, compute_row_sums, scratch2, context);
      }
    }
    if (proj_clip > 0.0f) {
      tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
    }
  } else {
    std::copy_n(scratch0, n_batch * n_output, output_state);
  }
}

// Calculates a single LSTM gate, int8x8_16 version.
// Implements the same functionality as CalculateLstmGateFloat.
void CalculateLstmGateInteger8x8_16(
    // Input and weights
    const int8_t* input, const int8_t* input_to_gate_weights,
    const int32_t* input_to_gate_bias, const int32_t input_to_gate_scale_a,
    const int32_t input_to_gate_scale_b,
    // Output state and weights
    const int8_t* output_state, const int8_t* recurrent_to_gate_weights,
    const int32_t* recurrent_to_gate_bias,
    const int32_t recurrent_to_gate_scale_a,
    const int32_t recurrent_to_gate_scale_b,
    // Cell state and weights
    const int16_t* cell_state, const int16_t* cell_to_gate_weights,
    const int32_t cell_to_gate_scale_a, const int32_t cell_to_gate_scale_b,
    // Layer normalization parameters (layer norm LSTM)
    const int16_t* layer_norm_coefficients, const int32_t* layer_norm_bias,
    const int32_t layer_norm_input_scale_a,
    const int32_t layer_norm_input_scale_b,
    const int32_t layer_norm_variance_guard,
    // Array sizes
    const int n_batch, const int n_input, const int n_output, const int n_cell,
    const TfLiteFusedActivation activation,
    // Output
    int16_t* gate,
    // Parameters for performance optimizations
    CpuBackendContext* context,
    // Scratch arrays
    int32_t* scratch5) {
  const bool use_peephole = (cell_to_gate_weights != nullptr);
  const bool use_layer_norm = (layer_norm_coefficients != nullptr);

  // Initialize scratch buffers with zeros. Note that unlike float and hybrid
  // versions, bias is only used in layer normalization.
  std::fill_n(gate, n_batch * n_cell, 0);
  // For each batch and cell: compute input_weight * input.
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      input, input_to_gate_bias, input_to_gate_weights, input_to_gate_scale_a,
      input_to_gate_scale_b, n_batch, n_input, n_cell, 0, scratch5, gate,
      context);
  // Note: no aux_input.

  // For each batch and cell: compute recurrent_weight * output_state.
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      output_state, recurrent_to_gate_bias, recurrent_to_gate_weights,
      recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
      n_cell, 0, scratch5, gate, context);
  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
  if (use_peephole) {
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_gate_weights, n_output, cell_state, n_batch,
        cell_to_gate_scale_a, cell_to_gate_scale_b, gate);
  }
  // Do layer normalization (if layer norm LSTM)
  if (use_layer_norm) {
    tensor_utils::ApplyLayerNorm(
        gate, layer_norm_coefficients, layer_norm_bias,
        layer_norm_input_scale_a, layer_norm_input_scale_b,
        layer_norm_variance_guard, n_batch, n_cell, gate);
  }
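  // (Hedged note: the first argument to ApplyTanh below is the number of
  // integer bits of the gate's fixed-point representation, i.e. the
  // pre-activation gate appears to be treated as Q3.12 here; this is inferred
  // from the call, not a documented contract.)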
  // Apply activation
  switch (activation) {
    case kTfLiteActSigmoid:
      tensor_utils::ApplySigmoid(gate, n_batch, n_cell, gate);
      break;
    case kTfLiteActTanh:
      tensor_utils::ApplyTanh(3, gate, n_batch, n_cell, gate);
      break;
    default:
      // Only Sigmoid or Tanh is used.
      TFLITE_ASSERT_FALSE;
  }
}

// Updates the LSTM cell state, used by both integer LSTM versions.
// Also see UpdateLstmCellFloat.
//
// Parameters:
//  - n_batch, n_cell: sizes of vectors
//  - cell_state: input/output vector, size n_batch*n_cell
//  - cell_state_scale: scaling factor of cell state.
//  - input_gate: input vector, size n_batch*n_cell.
//  - forget_gate: input/scratch vector, size n_batch*n_cell, always modified.
//  - cell_gate: input vector, size n_batch*n_cell.
//  - use_cifg: use 1-forget_gate instead of input_gate.
//  - clip: if > 0, clip the resulting cell state to [-clip, +clip].
void UpdateLstmCellInteger(int n_batch, int n_cell, int16_t* cell_state,
                           int32_t cell_state_scale, const int16_t* input_gate,
                           int16_t* forget_gate, const int16_t* cell_gate,
                           bool use_cifg, int16_t clip) {
  // Use the forget_gate array as scratch, as input_gate array is not allocated
  // in CIFG case. (Be careful not to write to the scratch before reading the
  // forget gate data.)
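  // (Fixed-point sketch, hedged: the int16 gates appear to use Q0.15, so the
  // product of two gates carries 30 fractional bits; the right shift by
  // 30 + cell_state_scale below should bring that product back to the cell
  // state's power-of-two scale, while the first CwiseMul shifts by only 15
  // because just one Q0.15 operand is involved.)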
  int16_t* scratch = forget_gate;

  tensor_utils::CwiseMul(forget_gate, cell_state, n_batch, n_cell, 15,
                         cell_state);
  if (use_cifg) {
    tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
    tensor_utils::CwiseMul(scratch, cell_gate, n_batch, n_cell,
                           30 + cell_state_scale, scratch);
  } else {
    tensor_utils::CwiseMul(input_gate, cell_gate, n_batch, n_cell,
                           30 + cell_state_scale, scratch);
  }
  tensor_utils::CwiseAdd(cell_state, scratch, n_batch, n_cell, cell_state);

  if (clip > 0) {
    tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
  }
}

// Calculates the output state tensor of an LSTM step. See Float and hybrid
// versions as well.
//
// Parameters:
//  - n_batch: batch size; the number of distinct vectors in each array.
//  - n_cell, n_output: sizes of vectors.
//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
//  - cell_state_scale: scaling of cell_state.
//  - hidden_scale_[a|b]: effective scale of cell_state.*output_gate
//  - hidden_zp: zero_point for cell_state.*output_gate
//  - projection_weights, proj_scale_[a|b], projection_bias:
//      constant inputs, describing projection matrix and bias.
//  - output_state_zp: zero point of output_state. (Input, calibrated value.)
//  - quantized_proj_clip: if > 0, clip the output of the projection.
//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
//  - context: data for optimized MatrixBatchVectorMultiplyAccumulate.
//  - scratch0: scratch area of size n_batch*n_cell
//  - scratch1: scratch area of size n_batch*n_cell
//  - scratch2: scratch area used by MatrixBatchVectorMultiplyAccumulate
void CalculateLstmOutputInteger8x8_16(
    int n_batch, int n_cell, int n_output, const int16_t* cell_state,
    int32_t cell_state_scale, const int16_t* output_gate,
    int32_t hidden_scale_a, int32_t hidden_scale_b, int32_t hidden_zp,
    const int8_t* projection_weights, int32_t proj_scale_a,
    int32_t proj_scale_b, const int32_t* projection_bias,
    int32_t output_state_zp, int8_t quantized_proj_clip, int8_t* output_state,
    CpuBackendContext* context, int16_t* scratch0, int8_t* scratch1,
    int32_t* scratch2) {
  // Note: unlike float/hybrid, the activation is always Tanh.
  tensor_utils::ApplyTanh(15 + cell_state_scale, cell_state, n_batch, n_cell,
                          scratch0);
  tensor_utils::CwiseMul(output_gate, scratch0, hidden_scale_a, hidden_scale_b,
                         n_batch, n_cell, hidden_zp, scratch1);

  const bool use_projection = (projection_weights != nullptr);

  if (use_projection) {
    // Note: no bias like in float/hybrid
    std::fill_n(output_state, n_batch * n_output, 0);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        scratch1, projection_bias, projection_weights, proj_scale_a,
        proj_scale_b, n_batch, n_cell, n_output, output_state_zp, scratch2,
        output_state, context);
    if (quantized_proj_clip > 0) {
      tensor_utils::CwiseClipping(output_state, n_batch * n_output,
                                  quantized_proj_clip);
    }
  } else {
    std::copy_n(scratch1, n_batch * n_output, output_state);
  }
}

// Calculates a single LSTM gate, int8x8_8 version.
// Implements the same functionality as CalculateLstmGateFloat.
void CalculateLstmGateInteger8x8_8(
    // Inputs and weights
    const int8_t* input, int32_t input_zp, const int8_t* input_to_gate_weight,
    const int32_t input_to_gate_scale_a, const int32_t input_to_gate_scale_b,
    const int32_t input_times_weights_scale_a,
    const int32_t input_times_weights_scale_b,
    const int32_t input_times_weights_zp,
    // Output state and weights
    const int8_t* output_state, const int32_t output_state_zp,
    const int8_t* recurrent_to_gate_weight,
    const int32_t recurrent_to_gate_scale_a,
    const int32_t recurrent_to_gate_scale_b,
    const int32_t output_state_times_weights_scale_a,
    const int32_t output_state_times_weights_scale_b,
    const int32_t output_state_times_weights_zp,
    // Layer normalization parameters (layer norm LSTM)
    const int16_t* layer_norm_gate_weight,
    const int32_t layer_norm_gate_scale_a,
    const int32_t layer_norm_gate_scale_b, const int32_t* gate_bias,
    // Array sizes
    const int n_batch, const int n_input, const int n_output, const int n_cell,
    const TfLiteFusedActivation activation,
    // Output
    int16_t* gate,
    // Scratch arrays, both sized n_batch*n_cell
    int8_t* scratch0, int8_t* scratch1) {
  // Multiply input * input_weights => scratch0
  tensor_utils::MatrixBatchVectorMultiply(
      input, input_zp, input_to_gate_weight, input_to_gate_scale_a,
      input_to_gate_scale_b, n_batch, n_input, n_cell, scratch0,
      input_times_weights_zp);
  // Multiply output_state * recurrent_weights => scratch1
  tensor_utils::MatrixBatchVectorMultiply(
      output_state, output_state_zp, recurrent_to_gate_weight,
      recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
      n_cell, scratch1, output_state_times_weights_zp);
  // Add scratch0 + scratch1 => gate
  tensor_utils::TwoGateSaturatingAdd(
      scratch0, input_times_weights_zp, scratch1, output_state_times_weights_zp,
      input_times_weights_scale_a, input_times_weights_scale_b,
      output_state_times_weights_scale_a, output_state_times_weights_scale_b,
      n_batch, n_cell, gate);
  // Apply layer normalization.
  tensor_utils::ApplyLayerNormFloat(
      gate, layer_norm_gate_weight, layer_norm_gate_scale_a,
      layer_norm_gate_scale_b, gate_bias, n_batch, n_cell, gate);
  // Apply activation.
  switch (activation) {
    case kTfLiteActSigmoid:
      tensor_utils::ApplySigmoidFloat(gate, n_batch, n_cell, gate);
      break;
    case kTfLiteActTanh:
      tensor_utils::ApplyTanhFloat(gate, n_batch, n_cell, -12, gate);
      break;
    default:
      // Only Sigmoid or Tanh is used.
      TFLITE_ASSERT_FALSE;
  }
}

// Calculates the output state tensor of an LSTM step. See Float and hybrid
// versions as well.
//
// Parameters:
//  - n_batch: batch size; the number of distinct vectors in each array.
//  - n_cell, n_output: sizes of vectors.
//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
//  - projection_weights, proj_scale_[a|b], projection_bias:
//      constant inputs, describing projection matrix and bias.
//  - output_state_zp: zero point of the output state.
//  - quantized_proj_clip: if > 0, clip the output of the projection.
//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
//  - scratch: scratch area of size n_batch*n_cell
void CalculateLstmOutputInteger8x8_8(
    int n_batch, int n_cell, int n_output, const int16_t* cell_state,
    const int16_t* output_gate, const int8_t* projection_weights,
    int32_t proj_scale_a, int32_t proj_scale_b, const int32_t* projection_bias,
    int32_t output_state_zp, int32_t quantized_proj_clip, int8_t* output_state,
    int16_t* scratch) {
  // Note: unlike float/hybrid, the activation is always Tanh.
  tensor_utils::ApplyTanhFloat(cell_state, n_batch, n_cell, -15, scratch);
  tensor_utils::CwiseMul(output_gate, scratch, n_batch, n_cell, 15 + 15 - 15,
                         scratch);
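  // (Hedged note on the shift amount: output_gate and scratch both appear to be
  // Q0.15, so their product has 15 + 15 fractional bits, and shifting right by
  // 15 + 15 - 15 keeps the result in Q0.15 for the projection below.)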
  // Note: no bias like in float/hybrid
  tensor_utils::MatrixBatchVectorMultiply(
      scratch, projection_weights, proj_scale_a, proj_scale_b, projection_bias,
      n_batch, n_cell, n_output, output_state_zp, output_state);
  if (quantized_proj_clip > 0) {
    tensor_utils::CwiseClipping(output_state, n_batch * n_output,
                                quantized_proj_clip);
  }
}

// Performs an LSTM batch inference step for input specified by input_ptr.
// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
// biases (*_bias_ptr), and buffers (*_scratch), along with additional
// parameters:
//  - params: various LSTM params including activation, clipping, etc.,
//  - n_batch: size of batch,
//  - n_cell: number of cells (or units),
//  - n_input: the input size,
//  - n_aux_input: the auxiliary input size.
//  - n_output: the output size.
//  - output_batch_leading_dim: the leading dimension of the output buffer.
//
// Input of size 'n_batch * n_input':
//   input_ptr
// Input of size 'n_batch * n_aux_input':
//   aux_input_ptr                     - optional (can be nullptr)
//
// LSTM weights:
// Input weights of size 'n_cell * n_input':
//   input_to_input_weights            - optional
//   input_to_forget_weights
//   input_to_cell_weights
//   input_to_output_weights
// Auxiliary input weights of size 'n_cell * n_aux_input':
//   aux_input_to_input_weights        - optional
//   aux_input_to_forget_weights       - optional
//   aux_input_to_cell_weights         - optional
//   aux_input_to_output_weights       - optional
// Recurrent weights of size 'n_cell * n_output':
//   recurrent_to_input_weights        - optional
//   recurrent_to_forget_weights
//   recurrent_to_cell_weights
//   recurrent_to_output_weights
// Peephole weights of size 'n_cell', representing diagonal matrices.
//   cell_to_input_weights             - optional
//   cell_to_forget_weights            - optional
//   cell_to_output_weights            - optional
// Projection weights of size 'n_output * n_cell'
//   projection_weights_ptr            - optional
// Gate biases of size 'n_cell':
//   input_gate_bias_ptr               - optional
//   forget_gate_bias_ptr
//   cell_gate_bias_ptr
//   output_gate_bias_ptr
//
// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
//   input_layer_norm_coefficients_ptr  - optional
//   forget_layer_norm_coefficients_ptr - optional
//   cell_layer_norm_coefficients_ptr   - optional
//   output_layer_norm_coefficients_ptr - optional
//
// The pointers to the cell and output state and the output are updated.
//
// The pointers input_ptr, aux_input_ptr, and output_ptr point to data laid out
// in batch-major order, and each step processes n_batch many inputs from
// input_ptr, and updates n_batch many cell and output states.
//
// The output_batch_leading_dim is output.shape[-1], i.e. the innermost (last)
// dimension of the output tensor, and in most cases equals n_output. It differs
// when we want to store the LSTM output into a slice of the output tensor, e.g.
// for bidirectional LSTMs with merge_outputs. In this case, the batched
// operations cannot be used, since they assume that the batched outputs are
// contiguous, and we manually loop over the batched outputs.
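//
// Roughly, each call performs (see the body for the authoritative order):
//   1. CalculateLstmGateFloat for the input (unless CIFG), forget, cell and
//      output gates,
//   2. UpdateLstmCellFloat to combine the forget/input/cell gates into
//      cell_state,
//   3. CalculateLstmOutputFloat to produce output_state, and
//   4. a per-batch copy of output_state into output_ptr, using
//      output_batch_leading_dim as the row stride.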
// LINT.IfChange
inline void LstmStepFloat(
    const float* input_ptr, const float* input_to_input_weights_ptr,
    const float* input_to_forget_weights_ptr,
    const float* input_to_cell_weights_ptr,
    const float* input_to_output_weights_ptr, const float* aux_input_ptr,
    const float* aux_input_to_input_weights_ptr,
    const float* aux_input_to_forget_weights_ptr,
    const float* aux_input_to_cell_weights_ptr,
    const float* aux_input_to_output_weights_ptr,
    const float* recurrent_to_input_weights_ptr,
    const float* recurrent_to_forget_weights_ptr,
    const float* recurrent_to_cell_weights_ptr,
    const float* recurrent_to_output_weights_ptr,
    const float* cell_to_input_weights_ptr,
    const float* cell_to_forget_weights_ptr,
    const float* cell_to_output_weights_ptr,
    const float* input_layer_norm_coefficients_ptr,
    const float* forget_layer_norm_coefficients_ptr,
    const float* cell_layer_norm_coefficients_ptr,
    const float* output_layer_norm_coefficients_ptr,
    const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
    const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
    const float* projection_weights_ptr, const float* projection_bias_ptr,
    const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
    int n_aux_input, int n_output, int output_batch_leading_dim,
    float* output_state_ptr, float* cell_state_ptr, float* scratch0,
    float* scratch1, float* scratch2, float* scratch3, float* output_ptr) {
  ruy::profiler::ScopeLabel label("LstmStepFloat");
  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights_ptr == nullptr);

  // Make named scratch buffers.
  float* input_gate_scratch = scratch0;
  float* forget_gate_scratch = scratch1;
  float* cell_gate_scratch = scratch2;
  float* output_gate_scratch = scratch3;

  // Check if inputs are all zeros so we can skip some computations.
  const bool is_input_all_zeros =
      tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
  const bool is_aux_input_all_zeros =
      (aux_input_ptr == nullptr ||
       tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
  if (!use_cifg) {
    // Calculate the input gate. (If not CIFG.)
    CalculateLstmGateFloat(
        input_ptr, input_to_input_weights_ptr, aux_input_ptr,
        aux_input_to_input_weights_ptr, output_state_ptr,
        recurrent_to_input_weights_ptr, cell_state_ptr,
        cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
        input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
        /*activation=*/kTfLiteActSigmoid, input_gate_scratch,
        is_input_all_zeros, is_aux_input_all_zeros);
  }
  // Calculate the forget gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
      aux_input_to_forget_weights_ptr, output_state_ptr,
      recurrent_to_forget_weights_ptr, cell_state_ptr,
      cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
      forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      /*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros,
      is_aux_input_all_zeros);
  // Calculate the cell update gate.
  CalculateLstmGateFloat(input_ptr, input_to_cell_weights_ptr, aux_input_ptr,
                         aux_input_to_cell_weights_ptr, output_state_ptr,
                         recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
                         /*cell_to_gate_weights=*/nullptr,
                         cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr,
                         n_batch, n_input, n_aux_input, n_output, n_cell,
                         params->activation, cell_gate_scratch,
                         is_input_all_zeros, is_aux_input_all_zeros);
  // Update the cell state.
  UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
                      forget_gate_scratch, cell_gate_scratch, use_cifg,
                      params->cell_clip);
  // Calculate output gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_output_weights_ptr, aux_input_ptr,
      aux_input_to_output_weights_ptr, output_state_ptr,
      recurrent_to_output_weights_ptr, cell_state_ptr,
      cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
      output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      /*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
      is_aux_input_all_zeros);
  // Update the output state.
  CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr,
                           output_gate_scratch, params->activation,
                           projection_weights_ptr, projection_bias_ptr,
                           params->proj_clip, output_state_ptr, scratch2);
  // Copy output state to the output. Note that the output's rows may not be
  // contiguous (output_batch_leading_dim != n_output).
  for (int b = 0; b < n_batch; b++) {
    std::copy_n(output_state_ptr + b * n_output, n_output,
                output_ptr + b * output_batch_leading_dim);
  }
}
906 // LINT.ThenChange(../tools/optimize/calibration/builtin_logging_ops/lstm.cc,\
907 //                 ../experimental/kernels/fp16/lstm_eval.cc)
908 
909 // Same as above but with quantized weight matrices. In detail:
910 // Input of size 'n_batch * n_input':
911 //   input_ptr
912 // Input of size 'n_batch * n_aux_input':
913 //   aux_input_ptr                     - optional (can be nullptr)
914 //
915 // LSTM weights:
916 // Quantized input weights of size 'n_cell * n_input':
917 //   input_to_input_weights            - optional
918 //   input_to_forget_weights
919 //   input_to_cell_weights
920 //   input_to_input_weights
921 // Quantized auxiliary input weights of size 'n_cell * n_aux_input':
922 //   aux_input_to_input_weights        - optional
923 //   aux_input_to_forget_weights       - optional
924 //   aux_input_to_cell_weights         - optional
925 //   aux_input_to_output_weights       - optional
926 // Quantized recurrent weights of size 'n_cell * n_output':
927 //   recurrent_to_input_weights        - optional
928 //   recurrent_to_forget_weights
929 //   recurrent_to_cell_weights
930 //   recurrent_to_input_weights
931 // Quantized peephole weights of size 'n_cell', representing diagonal matrices.
932 //   cell_to_input_weights             - optional
933 //   cell_to_cell_weights              - optional
934 //   cell_to_output_weights            - optional
935 // Quantized projection weights of size 'n_output * n_cell'
936 //   projection_weights_ptr            - optional
937 // Weight scales (scalars) for each of the weights above.
938 //   input_to_input_weights_scale      - optional
939 //   input_to_forget_weights_scale
940 //   input_to_cell_weights_scale
941 //   input_to_output_weights_scale
942 //   aux_input_to_input_weights_scale  - optional
943 //   aux_input_to_forget_weights_scale - optional
944 //   aux_input_to_cell_weights_scale   - optional
945 //   aux_input_to_output_weights_scale - optional
946 //   recurrent_to_input_weights_scale  - optional
947 //   recurrent_to_forget_weights_scale
948 //   recurrent_to_cell_weights_scale
949 //   recurrent_to_output_weights_scale
950 //   cell_to_input_weights_scale
951 //   cell_to_forget_weights_scale
952 //   cell_to_output_weights_scale
953 //   projection_weights_scale          - optional
954 // Gate biases of size 'n_cell':
955 //   input_gate_bias_ptr               - optional
956 //   forget_gate_bias_ptr
957 //   cell_gate_bias_ptr
958 //   output_gate_bias_ptr
959 //
960 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
961 //   input_layer_norm_coefficients_ptr  - optional
962 //   forget_layer_norm_coefficients_ptr - optional
963 //   cell_layer_norm_coefficients_ptr   - optional
964 //   output_layer_norm_coefficients_ptr - optional
965 //
966 // Temporary pre-allocated storage for quantized values:
967 //   quantized_input_ptr (same size as input_ptr)
968 //   quantized_output_state_ptr (same size as output_state_ptr)
969 //   quantized_output_scratch (same size as cell_state_ptr)
970 // Temporary pre-allocated storage for recovered values:
971 //   recovered_cell_weights (same size as cell_to_*_weights)
972 //
973 // Outputs:
974 //   output_state_ptr - size 'n_batch * n_output'
975 //   cell_state_ptr   - size 'n_batch * n_cell'
976 //   output_ptr       - size 'n_batch * output_batch_leading_dim'
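// In the hybrid path the weight matrices stay in int8 while activations stay
// in float: the input, aux input and output state are quantized on the fly
// (per batch row) before each matmul, and the accumulators are scaled back to
// float using the per-row scaling factors together with the weight scales
// listed above.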
977 inline void LstmStepHybrid(
978     const float* input_ptr, const int8_t* input_to_input_weights_ptr,
979     const uint8_t* input_to_input_weights_ledger_ptr,
980     float input_to_input_weights_scale,
981     const int8_t* input_to_forget_weights_ptr,
982     const uint8_t* input_to_forget_weights_ledger_ptr,
983     float input_to_forget_weights_scale,
984     const int8_t* input_to_cell_weights_ptr,
985     const uint8_t* input_to_cell_weights_ledger_ptr,
986     float input_to_cell_weights_scale,
987     const int8_t* input_to_output_weights_ptr,
988     const uint8_t* input_to_output_weights_ledger_ptr,
989     float input_to_output_weights_scale, const float* aux_input_ptr,
990     const int8_t* aux_input_to_input_weights_ptr,
991     float aux_input_to_input_weights_scale,
992     const int8_t* aux_input_to_forget_weights_ptr,
993     float aux_input_to_forget_weights_scale,
994     const int8_t* aux_input_to_cell_weights_ptr,
995     float aux_input_to_cell_weights_scale,
996     const int8_t* aux_input_to_output_weights_ptr,
997     float aux_input_to_output_weights_scale,
998     const int8_t* recurrent_to_input_weights_ptr,
999     const uint8_t* recurrent_to_input_weights_ledger_ptr,
1000     float recurrent_to_input_weights_scale,
1001     const int8_t* recurrent_to_forget_weights_ptr,
1002     const uint8_t* recurrent_to_forget_weights_ledger_ptr,
1003     float recurrent_to_forget_weights_scale,
1004     const int8_t* recurrent_to_cell_weights_ptr,
1005     const uint8_t* recurrent_to_cell_weights_ledger_ptr,
1006     float recurrent_to_cell_weights_scale,
1007     const int8_t* recurrent_to_output_weights_ptr,
1008     const uint8_t* recurrent_to_output_weights_ledger_ptr,
1009     float recurrent_to_output_weights_scale,
1010     const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
1011     const int8_t* cell_to_forget_weights_ptr,
1012     float cell_to_forget_weights_scale,
1013     const int8_t* cell_to_output_weights_ptr,
1014     float cell_to_output_weights_scale,
1015     const float* input_layer_norm_coefficients_ptr,
1016     const float* forget_layer_norm_coefficients_ptr,
1017     const float* cell_layer_norm_coefficients_ptr,
1018     const float* output_layer_norm_coefficients_ptr,
1019     const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
1020     const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
1021     const int8_t* projection_weights_ptr,
1022     const uint8_t* projection_weights_ledger_ptr,
1023     float projection_weights_scale, const float* projection_bias_ptr,
1024     const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
1025     int n_aux_input, int n_output, int output_batch_leading_dim,
1026     float* scratch0, float* scratch1, float* scratch2, float* scratch3,
1027     float* input_sf, float* aux_input_sf, float* output_state_sf,
1028     float* scaling_factors_scratch, float* recovered_cell_weights,
1029     int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
1030     int8_t* quantized_output_state_ptr, int8_t* quantized_output_scratch,
1031     float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr,
1032     float* output_ptr, int32_t* input_zp, int32_t* aux_input_zp,
1033     int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
1034     bool* compute_row_sums, bool asymmetric_quantize_inputs,
1035     CpuBackendContext* context) {
1036   ruy::profiler::ScopeLabel label("LstmStepHybrid");
1037   // Since we have already checked that weights are all there or none, we
1038   // can check the existence of only one to get the condition.
1039   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
1040   // Make named scratch buffers for the different gates.
1041   float* input_gate_scratch = scratch0;
1042   float* forget_gate_scratch = scratch1;
1043   float* cell_gate_scratch = scratch2;
1044   float* output_gate_scratch = scratch3;
1045 
1046   int32_t* input_to_input_row_sums = nullptr;
1047   int32_t* input_to_forget_row_sums = nullptr;
1048   int32_t* input_to_cell_row_sums = nullptr;
1049   int32_t* input_to_output_row_sums = nullptr;
1050   int32_t* aux_input_to_input_row_sums = nullptr;
1051   int32_t* aux_input_to_forget_row_sums = nullptr;
1052   int32_t* aux_input_to_cell_row_sums = nullptr;
1053   int32_t* aux_input_to_output_row_sums = nullptr;
1054   int32_t* recurrent_to_input_row_sums = nullptr;
1055   int32_t* recurrent_to_forget_row_sums = nullptr;
1056   int32_t* recurrent_to_cell_row_sums = nullptr;
1057   int32_t* recurrent_to_output_row_sums = nullptr;
1058   int32_t* projection_weights_row_sums = nullptr;
1059 
1060   if (asymmetric_quantize_inputs) {
1061     int num_row_sums = use_cifg ? 6 : 8;
1062     if (aux_input_ptr != nullptr) {
1063       num_row_sums += use_cifg ? 3 : 4;
1064     }
1065     if (projection_weights_ptr != nullptr) {
1066       num_row_sums += ceil(static_cast<float>(n_output) / n_cell);
1067     }
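    // (The projection matrix has n_output rows, so its row sums occupy
    // ceil(n_output / n_cell) of the n_cell-sized slots counted here; e.g.
    // n_output == 20, n_cell == 12 adds 2 slots.)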
1068     TF_LITE_ASSERT(row_sums_size == num_row_sums);
1069     input_to_input_row_sums = row_sums;
1070     input_to_forget_row_sums =
1071         use_cifg ? input_to_input_row_sums : input_to_input_row_sums + n_cell;
1072     input_to_cell_row_sums = input_to_forget_row_sums + n_cell;
1073     input_to_output_row_sums = input_to_cell_row_sums + n_cell;
1074     if (aux_input_ptr != nullptr) {
1075       aux_input_to_input_row_sums = input_to_output_row_sums + n_cell;
1076       aux_input_to_forget_row_sums = use_cifg
1077                                          ? aux_input_to_input_row_sums
1078                                          : aux_input_to_input_row_sums + n_cell;
1079       aux_input_to_cell_row_sums = aux_input_to_forget_row_sums + n_cell;
1080       aux_input_to_output_row_sums = aux_input_to_cell_row_sums + n_cell;
1081     }
1082     recurrent_to_input_row_sums = aux_input_ptr
1083                                       ? aux_input_to_output_row_sums + n_cell
1084                                       : input_to_output_row_sums + n_cell;
1085     recurrent_to_forget_row_sums = use_cifg
1086                                        ? recurrent_to_input_row_sums
1087                                        : recurrent_to_input_row_sums + n_cell;
1088     recurrent_to_cell_row_sums = recurrent_to_forget_row_sums + n_cell;
1089     recurrent_to_output_row_sums = recurrent_to_cell_row_sums + n_cell;
1090     if (projection_weights_ptr != nullptr) {
1091       projection_weights_row_sums = recurrent_to_output_row_sums + n_cell;
1092     }
1093     if (*compute_row_sums) {
1094       ComputeRowSums(
1095           input_to_input_row_sums, input_to_forget_row_sums,
1096           input_to_cell_row_sums, input_to_output_row_sums,
1097           aux_input_to_input_row_sums, aux_input_to_forget_row_sums,
1098           aux_input_to_cell_row_sums, aux_input_to_output_row_sums,
1099           recurrent_to_input_row_sums, recurrent_to_forget_row_sums,
1100           recurrent_to_cell_row_sums, recurrent_to_output_row_sums,
1101           projection_weights_row_sums, row_sums, n_cell, n_input, n_aux_input,
1102           n_output, input_to_input_weights_ptr, input_to_forget_weights_ptr,
1103           input_to_cell_weights_ptr, input_to_output_weights_ptr,
1104           aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
1105           aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
1106           recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
1107           recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
1108           projection_weights_ptr, use_cifg, aux_input_ptr);
1109       *compute_row_sums = false;
1110     }
1111   }
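  // Layout of the row_sums buffer (consecutive n_cell-sized vectors):
  // input_to_{input,forget,cell,output}, then aux_input_to_{...} when an aux
  // input is present, then recurrent_to_{input,forget,cell,output}, and
  // finally the projection row sums. Under CIFG the *_to_input entries are not
  // allocated (the *_to_forget pointers take their place).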
1112 
1113   // Check if inputs are all zeros so we can skip some computations.
1114   const bool is_input_all_zeros =
1115       tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
1116   const bool is_aux_input_all_zeros =
1117       (aux_input_ptr == nullptr ||
1118        tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
1119   const bool is_output_state_all_zeros =
1120       tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output);
1121   // Quantize inputs.
1122   if (!is_input_all_zeros) {
1123     tensor_utils::BatchQuantizeFloats(input_ptr, n_batch, n_input,
1124                                       quantized_input_ptr, input_sf, input_zp,
1125                                       asymmetric_quantize_inputs);
1126   }
1127   if (!is_aux_input_all_zeros) {
1128     tensor_utils::BatchQuantizeFloats(aux_input_ptr, n_batch, n_aux_input,
1129                                       quantized_aux_input_ptr, aux_input_sf,
1130                                       aux_input_zp, asymmetric_quantize_inputs);
1131   }
1132   if (!is_output_state_all_zeros) {
1133     tensor_utils::BatchQuantizeFloats(
1134         output_state_ptr, n_batch, n_output, quantized_output_state_ptr,
1135         output_state_sf, output_state_zp, asymmetric_quantize_inputs);
1136   }
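  // (BatchQuantizeFloats produces, per batch row, the int8 values plus a
  // scaling factor, and also a zero point when asymmetric_quantize_inputs is
  // set; all-zero vectors are skipped above since they contribute nothing to
  // the gate matmuls.)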
1137   if (!use_cifg) {
1138     // Calculate the input gate. (If not CIFG.)
1139     CalculateLstmGateHybrid(
1140         quantized_input_ptr, input_sf, input_zp, input_to_input_weights_ptr,
1141         input_to_input_weights_ledger_ptr, input_to_input_weights_scale,
1142         input_to_input_row_sums, quantized_aux_input_ptr, aux_input_sf,
1143         aux_input_zp, aux_input_to_input_weights_ptr,
1144         aux_input_to_input_weights_scale, aux_input_to_input_row_sums,
1145         quantized_output_state_ptr, output_state_sf, output_state_zp,
1146         recurrent_to_input_weights_ptr, recurrent_to_input_weights_ledger_ptr,
1147         recurrent_to_input_weights_scale, recurrent_to_input_row_sums,
1148         cell_state_ptr, cell_to_input_weights_ptr, cell_to_input_weights_scale,
1149         input_layer_norm_coefficients_ptr, input_gate_bias_ptr, n_batch,
1150         n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
1151         input_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
1152         is_output_state_all_zeros, compute_row_sums, context,
1153         scaling_factors_scratch, recovered_cell_weights, accum_scratch_ptr);
1154   }
1155   // Calculate the forget gate.
1156   CalculateLstmGateHybrid(
1157       quantized_input_ptr, input_sf, input_zp, input_to_forget_weights_ptr,
1158       input_to_forget_weights_ledger_ptr, input_to_forget_weights_scale,
1159       input_to_forget_row_sums, quantized_aux_input_ptr, aux_input_sf,
1160       aux_input_zp, aux_input_to_forget_weights_ptr,
1161       aux_input_to_forget_weights_scale, aux_input_to_forget_row_sums,
1162       quantized_output_state_ptr, output_state_sf, output_state_zp,
1163       recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_ledger_ptr,
1164       recurrent_to_forget_weights_scale, recurrent_to_forget_row_sums,
1165       cell_state_ptr, cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
1166       forget_layer_norm_coefficients_ptr, forget_gate_bias_ptr, n_batch,
1167       n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
1168       forget_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
1169       is_output_state_all_zeros, compute_row_sums, context,
1170       scaling_factors_scratch, recovered_cell_weights, accum_scratch_ptr);
1171   // Calculate the cell update gate.
1172   CalculateLstmGateHybrid(
1173       quantized_input_ptr, input_sf, input_zp, input_to_cell_weights_ptr,
1174       input_to_cell_weights_ledger_ptr, input_to_cell_weights_scale,
1175       input_to_cell_row_sums, quantized_aux_input_ptr, aux_input_sf,
1176       aux_input_zp, aux_input_to_cell_weights_ptr,
1177       aux_input_to_cell_weights_scale, aux_input_to_cell_row_sums,
1178       quantized_output_state_ptr, output_state_sf, output_state_zp,
1179       recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_ledger_ptr,
1180       recurrent_to_cell_weights_scale, recurrent_to_cell_row_sums,
1181       /*cell_state=*/nullptr, /*cell_to_gate_weights=*/nullptr,
1182       /*cell_to_gate_weights_scale=*/0.0f, cell_layer_norm_coefficients_ptr,
1183       cell_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
1184       params->activation, cell_gate_scratch, is_input_all_zeros,
1185       is_aux_input_all_zeros, is_output_state_all_zeros, compute_row_sums,
1186       context, scaling_factors_scratch, recovered_cell_weights,
1187       accum_scratch_ptr);
1188   // Update the cell state.
1189   UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
1190                       forget_gate_scratch, cell_gate_scratch, use_cifg,
1191                       params->cell_clip);
1192   // Calculate the output gate.
1193   CalculateLstmGateHybrid(
1194       quantized_input_ptr, input_sf, input_zp, input_to_output_weights_ptr,
1195       input_to_output_weights_ledger_ptr, input_to_output_weights_scale,
1196       input_to_output_row_sums, quantized_aux_input_ptr, aux_input_sf,
1197       aux_input_zp, aux_input_to_output_weights_ptr,
1198       aux_input_to_output_weights_scale, aux_input_to_output_row_sums,
1199       quantized_output_state_ptr, output_state_sf, output_state_zp,
1200       recurrent_to_output_weights_ptr, recurrent_to_output_weights_ledger_ptr,
1201       recurrent_to_output_weights_scale, recurrent_to_output_row_sums,
1202       cell_state_ptr, cell_to_output_weights_ptr, cell_to_output_weights_scale,
1203       output_layer_norm_coefficients_ptr, output_gate_bias_ptr, n_batch,
1204       n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
1205       output_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
1206       is_output_state_all_zeros, compute_row_sums, context,
1207       scaling_factors_scratch, recovered_cell_weights, accum_scratch_ptr);
1208   // Update the output state.
1209   CalculateLstmOutputHybrid(
1210       n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
1211       params->activation, projection_weights_ptr, projection_weights_ledger_ptr,
1212       projection_weights_scale, projection_bias_ptr, params->proj_clip,
1213       output_state_ptr, asymmetric_quantize_inputs, projection_weights_row_sums,
1214       compute_row_sums, context, scratch2, quantized_output_scratch, input_sf,
1215       input_zp, accum_scratch_ptr);
1216   // Copy output state to the output. Note that the output's rows may not be
1217   // contiguous (output_batch_leading_dim != n_output).
1218   for (int b = 0; b < n_batch; b++) {
1219     std::copy_n(output_state_ptr + b * n_output, n_output,
1220                 output_ptr + b * output_batch_leading_dim);
1221   }
1222 }
1223 
1224 // Fully quantized lstm kernel for 16 bit gate matmul output.
1225 //
1226 // Input tensor of size n_batch * n_input:
1227 //   input_ptr
1228 //
1229 // LSTM weights:
1230 // Quantized input weights of size 'n_cell * n_input':
1231 //   input_to_input_weight_ptr            - optional
1232 //   input_to_forget_weight_ptr
1233 //   input_to_cell_weight_ptr
1234 //   input_to_output_weight_ptr
1235 //
1236 // Quantized recurrent weights of size 'n_cell * n_output':
1237 //   recurrent_to_input_weight_ptr        - optional
1238 //   recurrent_to_forget_weights_ptr
1239 //   recurrent_to_cell_weights_ptr
1240 //   recurrent_to_output_weights_ptr
1241 //
1242 // Quantized peephole weights of size 'n_cell', representing diagonal matrices.
1243 //   cell_to_input_weights               - optional
1244 //   cell_to_forget_weights              - optional
1245 //   cell_to_output_weights              - optional
1246 //
1247 // Quantized projection weights of size 'n_output * n_cell'
1248 //   projection_weight_ptr                     - optional
1249 //
1250 // Weight scales (scalars) for each of the weights above.
1251 //   effective_input_to_input_scale_a    - optional
1252 //   effective_input_to_input_scale_b    - optional
1253 //   effective_input_to_forget_scale_a
1254 //   effective_input_to_forget_scale_b
1255 //   effective_input_to_cell_scale_a
1256 //   effective_input_to_cell_scale_b
1257 //   effective_input_to_output_scale_a
1258 //   effective_input_to_output_scale_b
1259 //   effective_recurrent_to_input_scale_a    - optional
1260 //   effective_recurrent_to_input_scale_b    - optional
1261 //   effective_recurrent_to_forget_scale_a
1262 //   effective_recurrent_to_forget_scale_b
1263 //   effective_recurrent_to_cell_scale_a
1264 //   effective_recurrent_to_cell_scale_b
1265 //   effective_recurrent_to_output_scale_a
1266 //   effective_recurrent_to_output_scale_b
1267 //   effective_proj_scale_a                  - optional
1268 //   effective_proj_scale_b                  - optional
1269 //
1270 // Gate biases of size 'n_cell':
1271 //   input_gate_bias_ptr                 - optional
1272 //   forget_gate_bias_ptr
1273 //   cell_gate_bias_ptr
1274 //   output_gate_bias_ptr
1275 //
1276 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
1277 //   layer_norm_input_weight_ptr    - optional
1278 //   layer_norm_forget_weight_ptr   - optional
1279 //   layer_norm_cell_weight_ptr     - optional
1280 //   layer_norm_output_weight_ptr   - optional
1281 //
1282 // Layer norm scales of size 'n_cell'.
1283 //   layer_norm_input_scale_a     - optional
1284 //   layer_norm_input_scale_b     - optional
1285 //   layer_norm_forget_scale_a    - optional
1286 //   layer_norm_forget_scale_b    - optional
1287 //   layer_norm_cell_scale_a      - optional
1288 //   layer_norm_cell_scale_b      - optional
1289 //   layer_norm_output_scale_a    - optional
1290 //   layer_norm_output_scale_b    - optional
1291 //
1292 // Scalar values:
1293 //   quantized_cell_clip: quantized clip value for cell.
1294 //   quantized_proj_clip: quantized clip value for projection.
1295 //   cell_state_scale: the power of two scale for cell state.
1296 //
1297 // Zero points:
1298 //   output_state_zp: zero point of output state.
1299 //   hidden_zp: zero point for hidden state.
1300 //
1301 // Temporary pre-allocated storage for the calculation. Each is of size n_cell *
1302 // n_batch.
1303 //   scratch0
1304 //   scratch1
1305 //   scratch2
1306 //   scratch3
1307 //   scratch4
1308 //   scratch5: this scratch buffer is created purely for optimizing the
1309 //              MatrixBatchVectorMultiplyAccumulate.
1310 //
1311 // Outputs:
1312 //   output_state_ptr - size 'n_batch * n_output'
1313 //   cell_state_ptr   - size 'n_batch * n_cell'
1314 //   output_ptr       - size 'n_batch * n_output'
1315 // TODO(b/159947023): scratch0 is not used if (!cifg). Don't allocate then.
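// Note: each effective_*_scale_a / effective_*_scale_b pair encodes a real
// multiplier in TFLite's usual fixed-point convention: scale_a is a 32-bit
// significand and scale_b a power-of-two exponent, so the represented scale is
// approximately scale_a * 2^(scale_b - 31).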
1316 inline void LstmStepInteger8x8_16(
1317     const int8_t* input_ptr, const int8_t* input_to_input_weight_ptr,
1318     int32_t effective_input_to_input_scale_a,
1319     int32_t effective_input_to_input_scale_b,
1320     const int8_t* input_to_forget_weight_ptr,
1321     int32_t effective_input_to_forget_scale_a,
1322     int32_t effective_input_to_forget_scale_b,
1323     const int8_t* input_to_cell_weight_ptr,
1324     int32_t effective_input_to_cell_scale_a,
1325     int32_t effective_input_to_cell_scale_b,
1326     const int8_t* input_to_output_weight_ptr,
1327     int32_t effective_input_to_output_scale_a,
1328     int32_t effective_input_to_output_scale_b,
1329     const int8_t* recurrent_to_input_weight_ptr,
1330     int32_t effective_recurrent_to_input_scale_a,
1331     int32_t effective_recurrent_to_input_scale_b,
1332     const int8_t* recurrent_to_forget_weight_ptr,
1333     int32_t effective_recurrent_to_forget_scale_a,
1334     int32_t effective_recurrent_to_forget_scale_b,
1335     const int8_t* recurrent_to_cell_weight_ptr,
1336     int32_t effective_recurrent_to_cell_scale_a,
1337     int32_t effective_recurrent_to_cell_scale_b,
1338     const int8_t* recurrent_to_output_weight_ptr,
1339     int32_t effective_recurrent_to_output_scale_a,
1340     int32_t effective_recurrent_to_output_scale_b,
1341     const int16_t* cell_to_input_weight_ptr,
1342     int32_t effective_cell_to_input_scale_a,
1343     int32_t effective_cell_to_input_scale_b,
1344     const int16_t* cell_to_forget_weight_ptr,
1345     int32_t effective_cell_to_forget_scale_a,
1346     int32_t effective_cell_to_forget_scale_b,
1347     const int16_t* cell_to_output_weight_ptr,
1348     int32_t effective_cell_to_output_scale_a,
1349     int32_t effective_cell_to_output_scale_b,
1350     const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
1351     int32_t effective_proj_scale_b, int32_t hidden_zp,
1352     int32_t effective_hidden_scale_a, int32_t effective_hidden_scale_b,
1353     const int16_t* layer_norm_input_weight_ptr,
1354     int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
1355     const int16_t* layer_norm_forget_weight_ptr,
1356     int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
1357     const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
1358     int32_t layer_norm_cell_scale_b,
1359     const int16_t* layer_norm_output_weight_ptr,
1360     int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
1361     const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
1362     const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
1363     int16_t quantized_cell_clip, int8_t quantized_proj_clip,
1364     int32_t cell_state_scale, int32_t input_variance_guard,
1365     int32_t forget_variance_guard, int32_t cell_variance_guard,
1366     int32_t output_variance_guard,
1367     const int32_t* input_to_forget_effective_bias,
1368     const int32_t* recurrent_to_forget_effective_bias,
1369     const int32_t* input_to_cell_effective_bias,
1370     const int32_t* recurrent_to_cell_effective_bias,
1371     const int32_t* input_to_output_effective_bias,
1372     const int32_t* recurrent_to_output_effective_bias,
1373     const int32_t* input_to_input_effective_bias,
1374     const int32_t* recurrent_to_input_effective_bias,
1375     const int32_t* projection_effective_bias, int n_batch, int n_cell,
1376     int n_input, int n_output, int8_t* output_state_ptr,
1377     int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
1378     int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3,
1379     int8_t* scratch4, int32_t* scratch5, CpuBackendContext* context) {
1380   ruy::profiler::ScopeLabel label("LstmStepInteger8x8_16");
1381   // Make named scratch buffers for the different gates.
1382   int16_t* input_gate_scratch = scratch0;
1383   int16_t* forget_gate_scratch = scratch1;
1384   int16_t* cell_gate_scratch = scratch2;
1385   int16_t* output_gate_scratch = scratch3;
1386 
1387   // Since we have already checked that weights are all there or none, we
1388   // can check the existence of only one to get the condition.
1389   const bool use_cifg = (input_to_input_weight_ptr == nullptr);
1390 
1391   // Check for nullptrs.
1392   TFLITE_DCHECK(input_to_forget_effective_bias);
1393   TFLITE_DCHECK(recurrent_to_forget_effective_bias);
1394   TFLITE_DCHECK(input_to_cell_effective_bias);
1395   TFLITE_DCHECK(recurrent_to_cell_effective_bias);
1396   TFLITE_DCHECK(input_to_output_effective_bias);
1397   TFLITE_DCHECK(recurrent_to_output_effective_bias);
1398   if (!use_cifg) {
1399     TFLITE_DCHECK(input_to_input_effective_bias);
1400     TFLITE_DCHECK(recurrent_to_input_effective_bias);
1401   }
1402   const bool use_projection = (projection_weight_ptr != nullptr);
1403   if (use_projection) {
1404     TFLITE_DCHECK(projection_effective_bias);
1405   }
1406   if (!use_cifg) {
1407     // Calculate the input gate. (If not CIFG.)
1408     CalculateLstmGateInteger8x8_16(
1409         input_ptr, input_to_input_weight_ptr, input_to_input_effective_bias,
1410         effective_input_to_input_scale_a, effective_input_to_input_scale_b,
1411         output_state_ptr, recurrent_to_input_weight_ptr,
1412         recurrent_to_input_effective_bias, effective_recurrent_to_input_scale_a,
1413         effective_recurrent_to_input_scale_b, cell_state_ptr,
1414         cell_to_input_weight_ptr, effective_cell_to_input_scale_a,
1415         effective_cell_to_input_scale_b, layer_norm_input_weight_ptr,
1416         input_gate_bias_ptr, layer_norm_input_scale_a, layer_norm_input_scale_b,
1417         input_variance_guard, n_batch, n_input, n_output, n_cell,
1418         kTfLiteActSigmoid, input_gate_scratch, context, scratch5);
1419   }
1420   // Calculate the forget gate.
1421   CalculateLstmGateInteger8x8_16(
1422       input_ptr, input_to_forget_weight_ptr, input_to_forget_effective_bias,
1423       effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
1424       output_state_ptr, recurrent_to_forget_weight_ptr,
1425       recurrent_to_forget_effective_bias, effective_recurrent_to_forget_scale_a,
1426       effective_recurrent_to_forget_scale_b, cell_state_ptr,
1427       cell_to_forget_weight_ptr, effective_cell_to_forget_scale_a,
1428       effective_cell_to_forget_scale_b, layer_norm_forget_weight_ptr,
1429       forget_gate_bias_ptr, layer_norm_forget_scale_a,
1430       layer_norm_forget_scale_b, forget_variance_guard, n_batch, n_input,
1431       n_output, n_cell, kTfLiteActSigmoid, forget_gate_scratch, context,
1432       scratch5);
1433   // Calculate the cell update gate.
1434   CalculateLstmGateInteger8x8_16(
1435       input_ptr, input_to_cell_weight_ptr, input_to_cell_effective_bias,
1436       effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
1437       output_state_ptr, recurrent_to_cell_weight_ptr,
1438       recurrent_to_cell_effective_bias, effective_recurrent_to_cell_scale_a,
1439       effective_recurrent_to_cell_scale_b, cell_state_ptr,
1440       /*cell_to_gate_weights=*/nullptr, /*cell_to_gate_scale_a=*/0,
1441       /*cell_to_gate_scale_b=*/0, layer_norm_cell_weight_ptr,
1442       cell_gate_bias_ptr, layer_norm_cell_scale_a, layer_norm_cell_scale_b,
1443       cell_variance_guard, n_batch, n_input, n_output, n_cell, kTfLiteActTanh,
1444       cell_gate_scratch, context, scratch5);
1445   // Update the cell state.
1446   UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr, cell_state_scale,
1447                         input_gate_scratch, forget_gate_scratch,
1448                         cell_gate_scratch, use_cifg, quantized_cell_clip);
1449   // Calculate the output gate.
1450   CalculateLstmGateInteger8x8_16(
1451       input_ptr, input_to_output_weight_ptr, input_to_output_effective_bias,
1452       effective_input_to_output_scale_a, effective_input_to_output_scale_b,
1453       output_state_ptr, recurrent_to_output_weight_ptr,
1454       recurrent_to_output_effective_bias, effective_recurrent_to_output_scale_a,
1455       effective_recurrent_to_output_scale_b, cell_state_ptr,
1456       cell_to_output_weight_ptr, effective_cell_to_output_scale_a,
1457       effective_cell_to_output_scale_b, layer_norm_output_weight_ptr,
1458       output_gate_bias_ptr, layer_norm_output_scale_a,
1459       layer_norm_output_scale_b, output_variance_guard, n_batch, n_input,
1460       n_output, n_cell, kTfLiteActSigmoid, output_gate_scratch, context,
1461       scratch5);
1462   // Update the output state.
1463   CalculateLstmOutputInteger8x8_16(
1464       n_batch, n_cell, n_output, cell_state_ptr, cell_state_scale,
1465       output_gate_scratch, effective_hidden_scale_a, effective_hidden_scale_b,
1466       hidden_zp, projection_weight_ptr, effective_proj_scale_a,
1467       effective_proj_scale_b, projection_effective_bias, output_state_zp,
1468       quantized_proj_clip, output_state_ptr, context, scratch0, scratch4,
1469       scratch5);
1470   // Copy output state to the output. Note that unlike float or hybrid, output
1471   // is always contiguous.
1472   std::copy_n(output_state_ptr, n_batch * n_output, output_ptr);
1473 }
1474 
1475 // Fully quantized lstm kernel for 8 bit gate matmul output.
1476 //
1477 // Input tensor of size n_batch * n_input:
1478 //   input_ptr
1479 //
1480 // LSTM weights:
1481 // Quantized input weights of size 'n_cell * n_input':
1482 //   input_to_input_weight_ptr            - optional
1483 //   input_to_forget_weight_ptr
1484 //   input_to_cell_weight_ptr
1485 //   input_to_output_weight_ptr
1486 //
1487 // Quantized recurrent weights of size 'n_cell * n_output':
1488 //   recurrent_to_input_weight_ptr        - optional
1489 //   recurrent_to_forget_weights_ptr
1490 //   recurrent_to_cell_weights_ptr
1491 //   recurrent_to_output_weights_ptr
1492 //
1493 // Quantized peephole weights of size 'n_cell', representing diagonal matrices.
1494 //   cell_to_input_weights               - optional
1495 //   cell_to_forget_weights              - optional
1496 //   cell_to_output_weights              - optional
1497 //
1498 // Quantized projection weights of size 'n_output * n_cell'
1499 //   projection_weight_ptr                     - optional
1500 //
1501 // Weight scales (scalars) for each of the weights above.
1502 //   effective_input_to_input_scale_a    - optional
1503 //   effective_input_to_input_scale_b    - optional
1504 //   effective_input_to_forget_scale_a
1505 //   effective_input_to_forget_scale_b
1506 //   effective_input_to_cell_scale_a
1507 //   effective_input_to_cell_scale_b
1508 //   effective_input_to_output_scale_a
1509 //   effective_input_to_output_scale_b
1510 //   effective_recurrent_to_input_scale_a    - optional
1511 //   effective_recurrent_to_input_scale_b    - optional
1512 //   effective_recurrent_to_forget_scale_a
1513 //   effective_recurrent_to_forget_scale_b
1514 //   effective_recurrent_to_cell_scale_a
1515 //   effective_recurrent_to_cell_scale_b
1516 //   effective_recurrent_to_output_scale_a
1517 //   effective_recurrent_to_output_scale_b
1518 //   effective_proj_scale_a                  - optional
1519 //   effective_proj_scale_b                  - optional
1520 //
1521 // Gate biases of size 'n_cell':
1522 //   input_gate_bias_ptr                 - optional
1523 //   forget_gate_bias_ptr
1524 //   cell_gate_bias_ptr
1525 //   output_gate_bias_ptr
1526 //
1527 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
1528 //   layer_norm_input_weight_ptr    - optional
1529 //   layer_norm_forget_weight_ptr   - optional
1530 //   layer_norm_cell_weight_ptr     - optional
1531 //   layer_norm_output_weight_ptr   - optional
1532 //
1533 // Layer norm scales of size 'n_cell'.
1534 //   layer_norm_input_scale_a     - optional
1535 //   layer_norm_input_scale_b     - optional
1536 //   layer_norm_forget_scale_a    - optional
1537 //   layer_norm_forget_scale_b    - optional
1538 //   layer_norm_cell_scale_a      - optional
1539 //   layer_norm_cell_scale_b      - optional
1540 //   layer_norm_output_scale_a    - optional
1541 //   layer_norm_output_scale_b    - optional
1542 //
1543 // Scalar values:
1544 //   quantized_cell_clip: quantized clip value for cell.
1545 //   quantized_proj_clip: quantized clip value for projection.
1546 //   cell_state_scale: the power of two scale for cell state.
1547 //
1548 // Zero points:
1549 //   output_state_zp: zero point of output state.
1550 //   hidden_zp: zero point for hidden state.
1551 //
1552 // Temporary pre-allocated storage for the calculation. Each is of size n_cell *
1553 // n_batch.
1554 //   scratch0
1555 //   scratch1
1556 //   scratch2
1557 //   scratch3
1558 //   scratch4
1559 //   scratch5
1560 //   scratch6
1561 //   scratch7
1562 //
1563 // Outputs:
1564 //   output_state_ptr - size 'n_batch * n_output'
1565 //   cell_state_ptr   - size 'n_batch * n_cell'
1566 //   output_ptr       - size 'n_batch * n_output'
1567 // TODO(b/148688698): Move zero point calculation into Prepare().
1568 // TODO(b/159947023): scratch5 is unused, remove.
1569 inline void LstmStepInteger8x8_8(
1570     const int8_t* input_ptr, int32_t input_zp,
1571     const int8_t* input_to_input_weight_ptr,
1572     int32_t effective_input_to_input_scale_a,
1573     int32_t effective_input_to_input_scale_b,
1574     const int8_t* input_to_forget_weight_ptr,
1575     int32_t effective_input_to_forget_scale_a,
1576     int32_t effective_input_to_forget_scale_b,
1577     const int8_t* input_to_cell_weight_ptr,
1578     int32_t effective_input_to_cell_scale_a,
1579     int32_t effective_input_to_cell_scale_b,
1580     const int8_t* input_to_output_weight_ptr,
1581     int32_t effective_input_to_output_scale_a,
1582     int32_t effective_input_to_output_scale_b,
1583     const int8_t* recurrent_to_input_weight_ptr,
1584     int32_t effective_recurrent_to_input_scale_a,
1585     int32_t effective_recurrent_to_input_scale_b,
1586     const int8_t* recurrent_to_forget_weight_ptr,
1587     int32_t effective_recurrent_to_forget_scale_a,
1588     int32_t effective_recurrent_to_forget_scale_b,
1589     const int8_t* recurrent_to_cell_weight_ptr,
1590     int32_t effective_recurrent_to_cell_scale_a,
1591     int32_t effective_recurrent_to_cell_scale_b,
1592     const int8_t* recurrent_to_output_weight_ptr,
1593     int32_t effective_recurrent_to_output_scale_a,
1594     int32_t effective_recurrent_to_output_scale_b,
1595     const int8_t* cell_to_input_weight_ptr,
1596     int32_t effective_cell_to_input_scale_a,
1597     int32_t effective_cell_to_input_scale_b,
1598     const int8_t* cell_to_forget_weight_ptr,
1599     int32_t effective_cell_to_forget_scale_a,
1600     int32_t effective_cell_to_forget_scale_b,
1601     const int8_t* cell_to_output_weight_ptr,
1602     int32_t effective_cell_to_output_scale_a,
1603     int32_t effective_cell_to_output_scale_b,
1604     const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
1605     int32_t effective_proj_scale_b, const int16_t* layer_norm_input_weight_ptr,
1606     int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
1607     const int16_t* layer_norm_forget_weight_ptr,
1608     int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
1609     const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
1610     int32_t layer_norm_cell_scale_b,
1611     const int16_t* layer_norm_output_weight_ptr,
1612     int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
1613     const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
1614     const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
1615     const int32_t* projection_bias_ptr, const TfLiteLSTMParams* params,
1616     const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b,
1617     const int32_t* intermediate_zp, int16_t quantized_cell_clip,
1618     int8_t quantized_proj_clip, int n_batch, int n_cell, int n_input,
1619     int n_output, int output_batch_leading_dim, int8_t* output_state_ptr,
1620     int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
1621     int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
1622     int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
1623     int16_t* scratch7) {
1624   // TODO(b/159066113): scratch5 is unused, remove.
1625 
1626   ruy::profiler::ScopeLabel label("LstmStepInteger8x8_8");
1627   // Make named scratch buffers for the different gates.
1628   int16_t* forget_gate_scratch = scratch2;
1629   int16_t* cell_gate_scratch = scratch3;
1630   int16_t* output_gate_scratch = scratch4;
1631   // Only CIFG is supported here; the input gate is coupled to the forget gate.
1632 
1633   // Calculate the forget gate.
1634   CalculateLstmGateInteger8x8_8(
1635       input_ptr, input_zp, input_to_forget_weight_ptr,
1636       effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
1637       intermediate_scale_a[2], intermediate_scale_b[2], intermediate_zp[4],
1638       output_state_ptr, output_state_zp, recurrent_to_forget_weight_ptr,
1639       effective_recurrent_to_forget_scale_a,
1640       effective_recurrent_to_forget_scale_b, intermediate_scale_a[3],
1641       intermediate_scale_b[3], intermediate_zp[5], layer_norm_forget_weight_ptr,
1642       layer_norm_forget_scale_a, layer_norm_forget_scale_b,
1643       forget_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
1644       kTfLiteActSigmoid, forget_gate_scratch, scratch0, scratch1);
1645   // Calculate the cell update gate.
1646   CalculateLstmGateInteger8x8_8(
1647       input_ptr, input_zp, input_to_cell_weight_ptr,
1648       effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
1649       intermediate_scale_a[4], intermediate_scale_b[4], intermediate_zp[7],
1650       output_state_ptr, output_state_zp, recurrent_to_cell_weight_ptr,
1651       effective_recurrent_to_cell_scale_a, effective_recurrent_to_cell_scale_b,
1652       intermediate_scale_a[5], intermediate_scale_b[5], intermediate_zp[8],
1653       layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
1654       layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_input, n_output,
1655       n_cell, kTfLiteActTanh, cell_gate_scratch, scratch0, scratch1);
1656   // Update the cell state.
1657   UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr,
1658                         /*cell_state_scale=*/-15, /*input_gate=*/nullptr,
1659                         forget_gate_scratch, cell_gate_scratch,
1660                         /*use_cifg=*/true, quantized_cell_clip);
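  // (A fixed cell_state_scale of -15 means the int16 cell state is read as
  // Q0.15, i.e. real_value = raw * 2^-15, covering roughly [-1, 1).)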
1661   // Calculate the output gate.
1662   CalculateLstmGateInteger8x8_8(
1663       input_ptr, input_zp, input_to_output_weight_ptr,
1664       effective_input_to_output_scale_a, effective_input_to_output_scale_b,
1665       intermediate_scale_a[6], intermediate_scale_b[6], intermediate_zp[10],
1666       output_state_ptr, output_state_zp, recurrent_to_output_weight_ptr,
1667       effective_recurrent_to_output_scale_a,
1668       effective_recurrent_to_output_scale_b, intermediate_scale_a[11],
1669       intermediate_scale_b[7], intermediate_zp[7], layer_norm_output_weight_ptr,
1670       layer_norm_output_scale_a, layer_norm_output_scale_b,
1671       output_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
1672       kTfLiteActSigmoid, output_gate_scratch, scratch0, scratch1);
1673   // Update the output state.
1674   CalculateLstmOutputInteger8x8_8(
1675       n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
1676       projection_weight_ptr, effective_proj_scale_a, effective_proj_scale_b,
1677       projection_bias_ptr, output_state_zp, quantized_proj_clip,
1678       output_state_ptr, scratch2);
1679   // Copy output state to the output. Note that unlike float or hybrid, output
1680   // is always contiguous.
1681   std::copy_n(output_state_ptr, n_batch * n_output, output_ptr);
1682 }
1683 
1684 }  // namespace
1685 
1686 // LINT.IfChange
1687 TfLiteStatus EvalFloat(
1688     const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
1689     const TfLiteTensor* input_to_forget_weights,
1690     const TfLiteTensor* input_to_cell_weights,
1691     const TfLiteTensor* input_to_output_weights,
1692     const TfLiteTensor* recurrent_to_input_weights,
1693     const TfLiteTensor* recurrent_to_forget_weights,
1694     const TfLiteTensor* recurrent_to_cell_weights,
1695     const TfLiteTensor* recurrent_to_output_weights,
1696     const TfLiteTensor* cell_to_input_weights,
1697     const TfLiteTensor* cell_to_forget_weights,
1698     const TfLiteTensor* cell_to_output_weights,
1699     const TfLiteTensor* input_layer_norm_coefficients,
1700     const TfLiteTensor* forget_layer_norm_coefficients,
1701     const TfLiteTensor* cell_layer_norm_coefficients,
1702     const TfLiteTensor* output_layer_norm_coefficients,
1703     const TfLiteTensor* aux_input,
1704     const TfLiteTensor* aux_input_to_input_weights,
1705     const TfLiteTensor* aux_input_to_forget_weights,
1706     const TfLiteTensor* aux_input_to_cell_weights,
1707     const TfLiteTensor* aux_input_to_output_weights,
1708     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
1709     const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
1710     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
1711     const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
1712     int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
1713     TfLiteTensor* cell_state, TfLiteTensor* output) {
1714   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
1715   int max_time, n_batch;
1716   if (input->dims->size == 3) {
1717     max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
1718     n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
1719   } else {
1720     max_time = 1;
1721     n_batch = input->dims->data[0];
1722   }
1723   const int n_input = input->dims->data[input->dims->size - 1];
1724   const int aux_input_size =
1725       (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
1726 
1727   // n_cell and n_output will be the same size when there is no projection.
1728   const int n_cell = input_to_output_weights->dims->data[0];
1729   const int n_output = recurrent_to_output_weights->dims->data[1];
1730 
1731   // Since we have already checked that weights are all there or none, we can
1732   // check the existence of only one to get the condition.
1733   const bool use_cifg = (input_to_input_weights == nullptr);
1734 
1735   // Index the scratch buffers pointers to the global scratch buffer.
1736   float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
1737   float* input_gate_scratch = nullptr;
1738   float* cell_gate_scratch = nullptr;
1739   float* forget_gate_scratch = nullptr;
1740   float* output_gate_scratch = nullptr;
1741   if (use_cifg) {
1742     cell_gate_scratch = scratch_buffer_ptr;
1743     forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
1744     output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
1745   } else {
1746     input_gate_scratch = scratch_buffer_ptr;
1747     cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
1748     forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
1749     output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
1750   }
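  // e.g. for a non-CIFG cell the scratch buffer holds four consecutive
  // n_cell * n_batch blocks in the order input, cell, forget, output gate;
  // under CIFG the input-gate block is omitted.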
1751 
1752   const int output_batch_leading_dim =
1753       output->dims->data[output->dims->size - 1];
1754   if (time_major) {
1755     // Loop through the sequence.
1756     const int input_step = n_batch * n_input;
1757     const int output_step = n_batch * output_batch_leading_dim;
1758     for (int t = 0; t < max_time; t++) {
1759       // If forward_sequence is true, step forward through time; otherwise
1760       // step backward.
1761       const int t_rel = forward_sequence ? t : max_time - t - 1;
1762       const float* input_ptr = GetTensorData<float>(input) + t_rel * input_step;
1763       const float* aux_input_ptr = nullptr;
1764       if (aux_input) {
1765         aux_input_ptr = GetTensorData<float>(aux_input) + t_rel * input_step;
1766       }
1767       float* output_ptr =
1768           GetTensorData<float>(output) + t_rel * output_step + output_offset;
1769 
1770       LstmStepFloat(
1771           input_ptr, GetTensorData<float>(input_to_input_weights),
1772           GetTensorData<float>(input_to_forget_weights),
1773           GetTensorData<float>(input_to_cell_weights),
1774           GetTensorData<float>(input_to_output_weights), aux_input_ptr,
1775           GetTensorData<float>(aux_input_to_input_weights),
1776           GetTensorData<float>(aux_input_to_forget_weights),
1777           GetTensorData<float>(aux_input_to_cell_weights),
1778           GetTensorData<float>(aux_input_to_output_weights),
1779           GetTensorData<float>(recurrent_to_input_weights),
1780           GetTensorData<float>(recurrent_to_forget_weights),
1781           GetTensorData<float>(recurrent_to_cell_weights),
1782           GetTensorData<float>(recurrent_to_output_weights),
1783           GetTensorData<float>(cell_to_input_weights),
1784           GetTensorData<float>(cell_to_forget_weights),
1785           GetTensorData<float>(cell_to_output_weights),
1786           GetTensorData<float>(input_layer_norm_coefficients),
1787           GetTensorData<float>(forget_layer_norm_coefficients),
1788           GetTensorData<float>(cell_layer_norm_coefficients),
1789           GetTensorData<float>(output_layer_norm_coefficients),
1790           GetTensorData<float>(input_gate_bias),
1791           GetTensorData<float>(forget_gate_bias),
1792           GetTensorData<float>(cell_gate_bias),
1793           GetTensorData<float>(output_gate_bias),
1794           GetTensorData<float>(projection_weights),
1795           GetTensorData<float>(projection_bias), params, n_batch, n_cell,
1796           n_input, aux_input_size, n_output, output_batch_leading_dim,
1797           GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
1798           input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
1799           output_gate_scratch, output_ptr);
1800     }
1801   } else {
1802     for (int b = 0; b < n_batch; b++) {
1803       const int input_step = n_input;
1804       const int output_step = output_batch_leading_dim;
1805       for (int t = 0; t < max_time; t++) {
1806         // If forward_sequence is true, step forward through time; otherwise
1807         // step backward.
1808         const int t_rel = forward_sequence ? t : max_time - t - 1;
1809         const int time_offset = b * max_time + t_rel;
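        // (Batch-major layout: sample b occupies rows b * max_time ..
        // b * max_time + max_time - 1, so step t_rel maps to row time_offset.)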
1810         const float* input_ptr =
1811             GetTensorData<float>(input) + time_offset * input_step;
1812         const float* aux_input_ptr = nullptr;
1813         if (aux_input) {
1814           aux_input_ptr =
1815               GetTensorData<float>(aux_input) + time_offset * input_step;
1816         }
1817         float* output_ptr = GetTensorData<float>(output) +
1818                             time_offset * output_step + output_offset;
1819 
1820         // Offset the {output,cell}_state pointers to the right batch.
1821         float* output_state_ptr =
1822             GetTensorData<float>(output_state) + b * output_batch_leading_dim;
1823         float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell;
1824         // Offset the scratch pointers to the right batch.
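        // (input_gate_scratch is nullptr when CIFG is used, hence the guard.)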
1825         float* input_gate_scratch_ptr =
1826             input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
1827         float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
1828         float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
1829         float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
1830 
1831         LstmStepFloat(
1832             input_ptr, GetTensorData<float>(input_to_input_weights),
1833             GetTensorData<float>(input_to_forget_weights),
1834             GetTensorData<float>(input_to_cell_weights),
1835             GetTensorData<float>(input_to_output_weights), aux_input_ptr,
1836             GetTensorData<float>(aux_input_to_input_weights),
1837             GetTensorData<float>(aux_input_to_forget_weights),
1838             GetTensorData<float>(aux_input_to_cell_weights),
1839             GetTensorData<float>(aux_input_to_output_weights),
1840             GetTensorData<float>(recurrent_to_input_weights),
1841             GetTensorData<float>(recurrent_to_forget_weights),
1842             GetTensorData<float>(recurrent_to_cell_weights),
1843             GetTensorData<float>(recurrent_to_output_weights),
1844             GetTensorData<float>(cell_to_input_weights),
1845             GetTensorData<float>(cell_to_forget_weights),
1846             GetTensorData<float>(cell_to_output_weights),
1847             GetTensorData<float>(input_layer_norm_coefficients),
1848             GetTensorData<float>(forget_layer_norm_coefficients),
1849             GetTensorData<float>(cell_layer_norm_coefficients),
1850             GetTensorData<float>(output_layer_norm_coefficients),
1851             GetTensorData<float>(input_gate_bias),
1852             GetTensorData<float>(forget_gate_bias),
1853             GetTensorData<float>(cell_gate_bias),
1854             GetTensorData<float>(output_gate_bias),
1855             GetTensorData<float>(projection_weights),
1856             GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
1857             n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
1858             output_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
1859             forget_gate_scratch_ptr, cell_gate_scratch_ptr,
1860             output_gate_scratch_ptr, output_ptr);
1861       }
1862     }
1863   }
1864   return kTfLiteOk;
1865 }
1866 // LINT.ThenChange(//tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc)
1867 
1868 TfLiteStatus EvalHybrid(
1869     const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
1870     const TfLiteTensor* input_to_input_weights_ledger,
1871     const TfLiteTensor* input_to_forget_weights,
1872     const TfLiteTensor* input_to_forget_weights_ledger,
1873     const TfLiteTensor* input_to_cell_weights,
1874     const TfLiteTensor* input_to_cell_weights_ledger,
1875     const TfLiteTensor* input_to_output_weights,
1876     const TfLiteTensor* input_to_output_weights_ledger,
1877     const TfLiteTensor* recurrent_to_input_weights,
1878     const TfLiteTensor* recurrent_to_input_weights_ledger,
1879     const TfLiteTensor* recurrent_to_forget_weights,
1880     const TfLiteTensor* recurrent_to_forget_weights_ledger,
1881     const TfLiteTensor* recurrent_to_cell_weights,
1882     const TfLiteTensor* recurrent_to_cell_weights_ledger,
1883     const TfLiteTensor* recurrent_to_output_weights,
1884     const TfLiteTensor* recurrent_to_output_weights_ledger,
1885     const TfLiteTensor* cell_to_input_weights,
1886     const TfLiteTensor* cell_to_forget_weights,
1887     const TfLiteTensor* cell_to_output_weights,
1888     const TfLiteTensor* input_layer_norm_coefficients,
1889     const TfLiteTensor* forget_layer_norm_coefficients,
1890     const TfLiteTensor* cell_layer_norm_coefficients,
1891     const TfLiteTensor* output_layer_norm_coefficients,
1892     const TfLiteTensor* aux_input,
1893     const TfLiteTensor* aux_input_to_input_weights,
1894     const TfLiteTensor* aux_input_to_forget_weights,
1895     const TfLiteTensor* aux_input_to_cell_weights,
1896     const TfLiteTensor* aux_input_to_output_weights,
1897     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
1898     const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
1899     const TfLiteTensor* projection_weights,
1900     const TfLiteTensor* projection_weights_ledger,
1901     const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params,
1902     bool forward_sequence, bool time_major, int output_offset,
1903     TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
1904     TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
1905     TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
1906     TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
1907     TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
1908     TfLiteTensor* output_state, TfLiteTensor* cell_state,
1909     TfLiteTensor* output_scratch_buffer, TfLiteTensor* output,
1910     TfLiteTensor* input_zp, TfLiteTensor* aux_input_zp,
1911     TfLiteTensor* output_state_zp, TfLiteTensor* row_sums, int row_sums_size,
1912     bool* compute_row_sums, CpuBackendContext* context) {
1913   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
1914   const int n_input = input->dims->data[input->dims->size - 1];
1915   int max_time, n_batch;
1916   if (input->dims->size == 2) {
1917     max_time = 1;
1918     n_batch = input->dims->data[0];
1919   } else {
1920     max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
1921     n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
1922   }
1923   const int aux_input_size =
1924       (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
1925   // n_cell and n_output will be the same size when there is no projection.
1926   const int n_cell = input_to_output_weights->dims->data[0];
1927   const int n_output = recurrent_to_output_weights->dims->data[1];
1928 
1929   // Since we have already checked that weights are all there or none, we can
1930   // check the existence of only one to get the condition.
1931   const bool use_cifg = (input_to_input_weights == nullptr);
1932 
1933   float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
1934   float* input_gate_scratch = nullptr;
1935   float* cell_gate_scratch = nullptr;
1936   float* forget_gate_scratch = nullptr;
1937   float* output_gate_scratch = nullptr;
1938   if (use_cifg) {
1939     cell_gate_scratch = scratch_buffer_ptr;
1940     forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
1941     output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
1942   } else {
1943     input_gate_scratch = scratch_buffer_ptr;
1944     cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
1945     forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
1946     output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
1947   }
1948 
1949   const int output_batch_leading_dim =
1950       output->dims->data[output->dims->size - 1];
1951 
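  // When activations are quantized asymmetrically, per-tensor zero points and
  // cached weight row sums are passed down to apply the zero-point correction
  // to the int8 matmuls; they remain nullptr on the symmetric path.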
1952   int32_t* input_zp_ptr = nullptr;
1953   int32_t* aux_input_zp_ptr = nullptr;
1954   int32_t* output_state_zp_ptr = nullptr;
1955   int32_t* row_sums_ptr = nullptr;
1956   if (params->asymmetric_quantize_inputs) {
1957     input_zp_ptr = GetTensorData<int32_t>(input_zp);
1958     aux_input_zp_ptr = GetTensorData<int32_t>(aux_input_zp);
1959     output_state_zp_ptr = GetTensorData<int32_t>(output_state_zp);
1960     row_sums_ptr = GetTensorData<int32_t>(row_sums);
1961   }
1962 
1963   if (time_major) {
1964     // Feed the sequence into the LSTM step-by-step.
1965     const int input_step = n_batch * n_input;
1966     const int output_step = n_batch * output_batch_leading_dim;
1967     for (int t = 0; t < max_time; t++) {
1968       // If this is the forward_sequence, step forward, otherwise step
1969       // backwards.
1970       const int t_rel = forward_sequence ? t : max_time - t - 1;
1971       const float* input_ptr = GetTensorData<float>(input) + t_rel * input_step;
1972       const float* aux_input_ptr = nullptr;
1973       if (aux_input) {
1974         aux_input_ptr = GetTensorData<float>(aux_input) + t_rel * input_step;
1975       }
1976       float* output_ptr =
1977           GetTensorData<float>(output) + t_rel * output_step + output_offset;
1978       LstmStepHybrid(
1979           input_ptr, GetTensorData<int8_t>(input_to_input_weights),
1980           GetTensorData<uint8_t>(input_to_input_weights_ledger),
1981           GetTensorScale(input_to_input_weights),
1982           GetTensorData<int8_t>(input_to_forget_weights),
1983           GetTensorData<uint8_t>(input_to_forget_weights_ledger),
1984           GetTensorScale(input_to_forget_weights),
1985           GetTensorData<int8_t>(input_to_cell_weights),
1986           GetTensorData<uint8_t>(input_to_cell_weights_ledger),
1987           GetTensorScale(input_to_cell_weights),
1988           GetTensorData<int8_t>(input_to_output_weights),
1989           GetTensorData<uint8_t>(input_to_output_weights_ledger),
1990           GetTensorScale(input_to_output_weights), aux_input_ptr,
1991           GetTensorData<int8_t>(aux_input_to_input_weights),
1992           GetTensorScale(aux_input_to_input_weights),
1993           GetTensorData<int8_t>(aux_input_to_forget_weights),
1994           GetTensorScale(aux_input_to_forget_weights),
1995           GetTensorData<int8_t>(aux_input_to_cell_weights),
1996           GetTensorScale(aux_input_to_cell_weights),
1997           GetTensorData<int8_t>(aux_input_to_output_weights),
1998           GetTensorScale(aux_input_to_output_weights),
1999           GetTensorData<int8_t>(recurrent_to_input_weights),
2000           GetTensorData<uint8_t>(recurrent_to_input_weights_ledger),
2001           GetTensorScale(recurrent_to_input_weights),
2002           GetTensorData<int8_t>(recurrent_to_forget_weights),
2003           GetTensorData<uint8_t>(recurrent_to_forget_weights_ledger),
2004           GetTensorScale(recurrent_to_forget_weights),
2005           GetTensorData<int8_t>(recurrent_to_cell_weights),
2006           GetTensorData<uint8_t>(recurrent_to_cell_weights_ledger),
2007           GetTensorScale(recurrent_to_cell_weights),
2008           GetTensorData<int8_t>(recurrent_to_output_weights),
2009           GetTensorData<uint8_t>(recurrent_to_output_weights_ledger),
2010           GetTensorScale(recurrent_to_output_weights),
2011           GetTensorData<int8_t>(cell_to_input_weights),
2012           GetTensorScale(cell_to_input_weights),
2013           GetTensorData<int8_t>(cell_to_forget_weights),
2014           GetTensorScale(cell_to_forget_weights),
2015           GetTensorData<int8_t>(cell_to_output_weights),
2016           GetTensorScale(cell_to_output_weights),
2017           GetTensorData<float>(input_layer_norm_coefficients),
2018           GetTensorData<float>(forget_layer_norm_coefficients),
2019           GetTensorData<float>(cell_layer_norm_coefficients),
2020           GetTensorData<float>(output_layer_norm_coefficients),
2021           GetTensorData<float>(input_gate_bias),
2022           GetTensorData<float>(forget_gate_bias),
2023           GetTensorData<float>(cell_gate_bias),
2024           GetTensorData<float>(output_gate_bias),
2025           GetTensorData<int8_t>(projection_weights),
2026           GetTensorData<uint8_t>(projection_weights_ledger),
2027           GetTensorScale(projection_weights),
2028           GetTensorData<float>(projection_bias), params, n_batch, n_cell,
2029           n_input, aux_input_size, n_output, output_batch_leading_dim,
2030           input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
2031           output_gate_scratch, GetTensorData<float>(input_sf),
2032           GetTensorData<float>(aux_input_sf),
2033           GetTensorData<float>(output_state_sf),
2034           GetTensorData<float>(prod_scaling_factors),
2035           GetTensorData<float>(recovered_cell_weights),
2036           GetTensorData<int8_t>(input_quantized),
2037           GetTensorData<int8_t>(aux_input_quantized),
2038           GetTensorData<int8_t>(output_state_quantized),
2039           GetTensorData<int8_t>(cell_state_quantized),
2040           GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
2041           GetTensorData<int32_t>(output_scratch_buffer), output_ptr,
2042           input_zp_ptr, aux_input_zp_ptr, output_state_zp_ptr, row_sums_ptr,
2043           row_sums_size, compute_row_sums, params->asymmetric_quantize_inputs,
2044           context);
2045     }
2046   } else {
2047     for (int b = 0; b < n_batch; b++) {
2048       const int input_step = n_input;
2049       const int output_step = output_batch_leading_dim;
2050       for (int t = 0; t < max_time; t++) {
2051         // If this is the forward_sequence, step forward, otherwise step
2052         // backwards.
2053         const int t_rel = forward_sequence ? t : max_time - t - 1;
2054         const int time_offset = b * max_time + t_rel;
2055         const float* input_ptr =
2056             GetTensorData<float>(input) + time_offset * input_step;
2057         const float* aux_input_ptr = nullptr;
2058         if (aux_input) {
2059           aux_input_ptr =
2060               GetTensorData<float>(aux_input) + time_offset * input_step;
2061         }
2062         float* output_ptr = GetTensorData<float>(output) +
2063                             time_offset * output_step + output_offset;
2064 
2065         // Offset the {output,cell}_state pointers to the right batch.
2066         float* output_state_ptr =
2067             GetTensorData<float>(output_state) + b * output_batch_leading_dim;
2068         float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell;
2069         // Offset the scratch pointers to the right batch.
2070         float* input_gate_scratch_ptr =
2071             input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
2072         float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
2073         float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
2074         float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
2075 
2076         LstmStepHybrid(
2077             input_ptr, GetTensorData<int8_t>(input_to_input_weights),
2078             GetTensorData<uint8_t>(input_to_input_weights_ledger),
2079             GetTensorScale(input_to_input_weights),
2080             GetTensorData<int8_t>(input_to_forget_weights),
2081             GetTensorData<uint8_t>(input_to_forget_weights_ledger),
2082             GetTensorScale(input_to_forget_weights),
2083             GetTensorData<int8_t>(input_to_cell_weights),
2084             GetTensorData<uint8_t>(input_to_cell_weights_ledger),
2085             GetTensorScale(input_to_cell_weights),
2086             GetTensorData<int8_t>(input_to_output_weights),
2087             GetTensorData<uint8_t>(input_to_output_weights_ledger),
2088             GetTensorScale(input_to_output_weights), aux_input_ptr,
2089             GetTensorData<int8_t>(aux_input_to_input_weights),
2090             GetTensorScale(aux_input_to_input_weights),
2091             GetTensorData<int8_t>(aux_input_to_forget_weights),
2092             GetTensorScale(aux_input_to_forget_weights),
2093             GetTensorData<int8_t>(aux_input_to_cell_weights),
2094             GetTensorScale(aux_input_to_cell_weights),
2095             GetTensorData<int8_t>(aux_input_to_output_weights),
2096             GetTensorScale(aux_input_to_output_weights),
2097             GetTensorData<int8_t>(recurrent_to_input_weights),
2098             GetTensorData<uint8_t>(recurrent_to_input_weights_ledger),
2099             GetTensorScale(recurrent_to_input_weights),
2100             GetTensorData<int8_t>(recurrent_to_forget_weights),
2101             GetTensorData<uint8_t>(recurrent_to_forget_weights_ledger),
2102             GetTensorScale(recurrent_to_forget_weights),
2103             GetTensorData<int8_t>(recurrent_to_cell_weights),
2104             GetTensorData<uint8_t>(recurrent_to_cell_weights_ledger),
2105             GetTensorScale(recurrent_to_cell_weights),
2106             GetTensorData<int8_t>(recurrent_to_output_weights),
2107             GetTensorData<uint8_t>(recurrent_to_output_weights_ledger),
2108             GetTensorScale(recurrent_to_output_weights),
2109             GetTensorData<int8_t>(cell_to_input_weights),
2110             GetTensorScale(cell_to_input_weights),
2111             GetTensorData<int8_t>(cell_to_forget_weights),
2112             GetTensorScale(cell_to_forget_weights),
2113             GetTensorData<int8_t>(cell_to_output_weights),
2114             GetTensorScale(cell_to_output_weights),
2115             GetTensorData<float>(input_layer_norm_coefficients),
2116             GetTensorData<float>(forget_layer_norm_coefficients),
2117             GetTensorData<float>(cell_layer_norm_coefficients),
2118             GetTensorData<float>(output_layer_norm_coefficients),
2119             GetTensorData<float>(input_gate_bias),
2120             GetTensorData<float>(forget_gate_bias),
2121             GetTensorData<float>(cell_gate_bias),
2122             GetTensorData<float>(output_gate_bias),
2123             GetTensorData<int8_t>(projection_weights),
2124             GetTensorData<uint8_t>(projection_weights_ledger),
2125             GetTensorScale(projection_weights),
2126             GetTensorData<float>(projection_bias), params,
2127             /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
2128             output_batch_leading_dim, input_gate_scratch_ptr,
2129             forget_gate_scratch_ptr, cell_gate_scratch_ptr,
2130             output_gate_scratch_ptr, GetTensorData<float>(input_sf),
2131             GetTensorData<float>(aux_input_sf),
2132             GetTensorData<float>(output_state_sf),
2133             GetTensorData<float>(prod_scaling_factors),
2134             GetTensorData<float>(recovered_cell_weights),
2135             GetTensorData<int8_t>(input_quantized),
2136             GetTensorData<int8_t>(aux_input_quantized),
2137             GetTensorData<int8_t>(output_state_quantized),
2138             GetTensorData<int8_t>(cell_state_quantized), output_state_ptr,
2139             cell_state_ptr, GetTensorData<int32_t>(output_scratch_buffer),
2140             output_ptr, input_zp_ptr, aux_input_zp_ptr, output_state_zp_ptr,
2141             row_sums_ptr, row_sums_size, compute_row_sums,
2142             params->asymmetric_quantize_inputs, context);
2143       }
2144     }
2145   }
2146 
2147   return kTfLiteOk;
2148 }
2149 
2150 TfLiteStatus EvalInteger8x8_16(
2151     const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
2152     const TfLiteTensor* input_to_forget_weights,
2153     const TfLiteTensor* input_to_cell_weights,
2154     const TfLiteTensor* input_to_output_weights,
2155     const TfLiteTensor* recurrent_to_input_weights,
2156     const TfLiteTensor* recurrent_to_forget_weights,
2157     const TfLiteTensor* recurrent_to_cell_weights,
2158     const TfLiteTensor* recurrent_to_output_weights,
2159     const TfLiteTensor* cell_to_input_weights,
2160     const TfLiteTensor* cell_to_forget_weights,
2161     const TfLiteTensor* cell_to_output_weights,
2162     const TfLiteTensor* input_layer_norm_coefficients,
2163     const TfLiteTensor* forget_layer_norm_coefficients,
2164     const TfLiteTensor* cell_layer_norm_coefficients,
2165     const TfLiteTensor* output_layer_norm_coefficients,
2166     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
2167     const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
2168     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
2169     const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
2170     const lstm_eval::IntegerLstmParameter* integer_lstm_param,
2171     TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output,
2172     TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
2173     TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
2174     CpuBackendContext* context) {
2175   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
2176   const int n_input = input->dims->data[input->dims->size - 1];
2177   int max_time, n_batch;
2178   if (input->dims->size == 2) {
2179     max_time = 1;
2180     n_batch = input->dims->data[0];
2181   } else {
2182     max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
2183     n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
2184   }
2185 
2186   // n_cell and n_output will be the same size when there is no projection.
2187   const int n_cell = input_to_output_weights->dims->data[0];
2188   const int n_output = recurrent_to_output_weights->dims->data[1];
2189 
2190   // Zero point of the output state activations.
2191   int output_state_zp = output_state->params.zero_point;
2192 
2193   // Get params for time/batch/sequence.
2194   const int output_batch_leading_dim =
2195       output->dims->data[output->dims->size - 1];
2196 
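  // Note: the time-major branch below always traverses the sequence forward
  // (t_rel == t); forward_sequence is only honored in the batch-major branch.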
2197   if (time_major) {
2198     const int input_step = n_batch * n_input;
2199     const int output_step = n_batch * output_batch_leading_dim;
2200     for (int t = 0; t < max_time; t++) {
2201       const int t_rel = t;
2202       int8_t* output_ptr = GetTensorData<int8_t>(output) + t_rel * output_step;
2203       const int8_t* input_ptr =
2204           GetTensorData<int8_t>(input) + t_rel * input_step;
2205       LstmStepInteger8x8_16(
2206           input_ptr, GetTensorData<int8_t>(input_to_input_weights),
2207           integer_lstm_param->effective_input_to_input_scale_a,
2208           integer_lstm_param->effective_input_to_input_scale_b,
2209           GetTensorData<int8_t>(input_to_forget_weights),
2210           integer_lstm_param->effective_input_to_forget_scale_a,
2211           integer_lstm_param->effective_input_to_forget_scale_b,
2212           GetTensorData<int8_t>(input_to_cell_weights),
2213           integer_lstm_param->effective_input_to_cell_scale_a,
2214           integer_lstm_param->effective_input_to_cell_scale_b,
2215           GetTensorData<int8_t>(input_to_output_weights),
2216           integer_lstm_param->effective_input_to_output_scale_a,
2217           integer_lstm_param->effective_input_to_output_scale_b,
2218           GetTensorData<int8_t>(recurrent_to_input_weights),
2219           integer_lstm_param->effective_recurrent_to_input_scale_a,
2220           integer_lstm_param->effective_recurrent_to_input_scale_b,
2221           GetTensorData<int8_t>(recurrent_to_forget_weights),
2222           integer_lstm_param->effective_recurrent_to_forget_scale_a,
2223           integer_lstm_param->effective_recurrent_to_forget_scale_b,
2224           GetTensorData<int8_t>(recurrent_to_cell_weights),
2225           integer_lstm_param->effective_recurrent_to_cell_scale_a,
2226           integer_lstm_param->effective_recurrent_to_cell_scale_b,
2227           GetTensorData<int8_t>(recurrent_to_output_weights),
2228           integer_lstm_param->effective_recurrent_to_output_scale_a,
2229           integer_lstm_param->effective_recurrent_to_output_scale_b,
2230           GetTensorData<int16_t>(cell_to_input_weights),
2231           integer_lstm_param->effective_cell_to_input_scale_a,
2232           integer_lstm_param->effective_cell_to_input_scale_b,
2233           GetTensorData<int16_t>(cell_to_forget_weights),
2234           integer_lstm_param->effective_cell_to_forget_scale_a,
2235           integer_lstm_param->effective_cell_to_forget_scale_b,
2236           GetTensorData<int16_t>(cell_to_output_weights),
2237           integer_lstm_param->effective_cell_to_output_scale_a,
2238           integer_lstm_param->effective_cell_to_output_scale_b,
2239           GetTensorData<int8_t>(projection_weights),
2240           integer_lstm_param->effective_proj_scale_a,
2241           integer_lstm_param->effective_proj_scale_b,
2242           integer_lstm_param->hidden_zp,
2243           integer_lstm_param->effective_hidden_scale_a,
2244           integer_lstm_param->effective_hidden_scale_b,
2245           GetTensorData<int16_t>(input_layer_norm_coefficients),
2246           integer_lstm_param->layer_norm_input_scale_a,
2247           integer_lstm_param->layer_norm_input_scale_b,
2248           GetTensorData<int16_t>(forget_layer_norm_coefficients),
2249           integer_lstm_param->layer_norm_forget_scale_a,
2250           integer_lstm_param->layer_norm_forget_scale_b,
2251           GetTensorData<int16_t>(cell_layer_norm_coefficients),
2252           integer_lstm_param->layer_norm_cell_scale_a,
2253           integer_lstm_param->layer_norm_cell_scale_b,
2254           GetTensorData<int16_t>(output_layer_norm_coefficients),
2255           integer_lstm_param->layer_norm_output_scale_a,
2256           integer_lstm_param->layer_norm_output_scale_b,
2257           GetTensorData<int32_t>(input_gate_bias),
2258           GetTensorData<int32_t>(forget_gate_bias),
2259           GetTensorData<int32_t>(cell_gate_bias),
2260           GetTensorData<int32_t>(output_gate_bias),
2261           integer_lstm_param->quantized_cell_clip,
2262           integer_lstm_param->quantized_proj_clip,
2263           integer_lstm_param->cell_scale,
2264           integer_lstm_param->input_variance_guard,
2265           integer_lstm_param->forget_variance_guard,
2266           integer_lstm_param->cell_variance_guard,
2267           integer_lstm_param->output_variance_guard,
2268           integer_lstm_param->input_to_forget_effective_bias.get(),
2269           integer_lstm_param->recurrent_to_forget_effective_bias.get(),
2270           integer_lstm_param->input_to_cell_effective_bias.get(),
2271           integer_lstm_param->recurrent_to_cell_effective_bias.get(),
2272           integer_lstm_param->input_to_output_effective_bias.get(),
2273           integer_lstm_param->recurrent_to_output_effective_bias.get(),
2274           integer_lstm_param->input_to_input_effective_bias.get(),
2275           integer_lstm_param->recurrent_to_input_effective_bias.get(),
2276           integer_lstm_param->projection_effective_bias.get(), n_batch, n_cell,
2277           n_input, n_output, GetTensorData<int8_t>(output_state),
2278           output_state_zp, GetTensorData<int16_t>(cell_state), output_ptr,
2279           GetTensorData<int16_t>(scratch0), GetTensorData<int16_t>(scratch1),
2280           GetTensorData<int16_t>(scratch2), GetTensorData<int16_t>(scratch3),
2281           GetTensorData<int8_t>(scratch4), GetTensorData<int32_t>(scratch5),
2282           context);
2283     }
2284   } else {
2285     for (int b = 0; b < n_batch; b++) {
2286       const int input_step = n_input;
2287       const int output_step = output_batch_leading_dim;
2288       for (int t = 0; t < max_time; t++) {
2289         // If this is the forward_sequence, step forward, otherwise step
2290         // backwards.
2291         const int t_rel = forward_sequence ? t : max_time - t - 1;
2292         const int time_offset = b * max_time + t_rel;
2293         const int8_t* input_ptr =
2294             GetTensorData<int8_t>(input) + time_offset * input_step;
2295         int8_t* output_ptr =
2296             GetTensorData<int8_t>(output) + time_offset * output_step;
2297 
2298         // Offset the {output,cell}_state pointers to the right batch.
2299         int8_t* output_state_ptr =
2300             GetTensorData<int8_t>(output_state) + b * output_batch_leading_dim;
2301         int16_t* cell_state_ptr =
2302             GetTensorData<int16_t>(cell_state) + b * n_cell;
2303 
2304         LstmStepInteger8x8_16(
2305             input_ptr, GetTensorData<int8_t>(input_to_input_weights),
2306             integer_lstm_param->effective_input_to_input_scale_a,
2307             integer_lstm_param->effective_input_to_input_scale_b,
2308             GetTensorData<int8_t>(input_to_forget_weights),
2309             integer_lstm_param->effective_input_to_forget_scale_a,
2310             integer_lstm_param->effective_input_to_forget_scale_b,
2311             GetTensorData<int8_t>(input_to_cell_weights),
2312             integer_lstm_param->effective_input_to_cell_scale_a,
2313             integer_lstm_param->effective_input_to_cell_scale_b,
2314             GetTensorData<int8_t>(input_to_output_weights),
2315             integer_lstm_param->effective_input_to_output_scale_a,
2316             integer_lstm_param->effective_input_to_output_scale_b,
2317             GetTensorData<int8_t>(recurrent_to_input_weights),
2318             integer_lstm_param->effective_recurrent_to_input_scale_a,
2319             integer_lstm_param->effective_recurrent_to_input_scale_b,
2320             GetTensorData<int8_t>(recurrent_to_forget_weights),
2321             integer_lstm_param->effective_recurrent_to_forget_scale_a,
2322             integer_lstm_param->effective_recurrent_to_forget_scale_b,
2323             GetTensorData<int8_t>(recurrent_to_cell_weights),
2324             integer_lstm_param->effective_recurrent_to_cell_scale_a,
2325             integer_lstm_param->effective_recurrent_to_cell_scale_b,
2326             GetTensorData<int8_t>(recurrent_to_output_weights),
2327             integer_lstm_param->effective_recurrent_to_output_scale_a,
2328             integer_lstm_param->effective_recurrent_to_output_scale_b,
2329             GetTensorData<int16_t>(cell_to_input_weights),
2330             integer_lstm_param->effective_cell_to_input_scale_a,
2331             integer_lstm_param->effective_cell_to_input_scale_b,
2332             GetTensorData<int16_t>(cell_to_forget_weights),
2333             integer_lstm_param->effective_cell_to_forget_scale_a,
2334             integer_lstm_param->effective_cell_to_forget_scale_b,
2335             GetTensorData<int16_t>(cell_to_output_weights),
2336             integer_lstm_param->effective_cell_to_output_scale_a,
2337             integer_lstm_param->effective_cell_to_output_scale_b,
2338             GetTensorData<int8_t>(projection_weights),
2339             integer_lstm_param->effective_proj_scale_a,
2340             integer_lstm_param->effective_proj_scale_b,
2341             integer_lstm_param->hidden_zp,
2342             integer_lstm_param->effective_hidden_scale_a,
2343             integer_lstm_param->effective_hidden_scale_b,
2344             GetTensorData<int16_t>(input_layer_norm_coefficients),
2345             integer_lstm_param->layer_norm_input_scale_a,
2346             integer_lstm_param->layer_norm_input_scale_b,
2347             GetTensorData<int16_t>(forget_layer_norm_coefficients),
2348             integer_lstm_param->layer_norm_forget_scale_a,
2349             integer_lstm_param->layer_norm_forget_scale_b,
2350             GetTensorData<int16_t>(cell_layer_norm_coefficients),
2351             integer_lstm_param->layer_norm_cell_scale_a,
2352             integer_lstm_param->layer_norm_cell_scale_b,
2353             GetTensorData<int16_t>(output_layer_norm_coefficients),
2354             integer_lstm_param->layer_norm_output_scale_a,
2355             integer_lstm_param->layer_norm_output_scale_b,
2356             GetTensorData<int32_t>(input_gate_bias),
2357             GetTensorData<int32_t>(forget_gate_bias),
2358             GetTensorData<int32_t>(cell_gate_bias),
2359             GetTensorData<int32_t>(output_gate_bias),
2360             integer_lstm_param->quantized_cell_clip,
2361             integer_lstm_param->quantized_proj_clip,
2362             integer_lstm_param->cell_scale,
2363             integer_lstm_param->input_variance_guard,
2364             integer_lstm_param->forget_variance_guard,
2365             integer_lstm_param->cell_variance_guard,
2366             integer_lstm_param->output_variance_guard,
2367             integer_lstm_param->input_to_forget_effective_bias.get(),
2368             integer_lstm_param->recurrent_to_forget_effective_bias.get(),
2369             integer_lstm_param->input_to_cell_effective_bias.get(),
2370             integer_lstm_param->recurrent_to_cell_effective_bias.get(),
2371             integer_lstm_param->input_to_output_effective_bias.get(),
2372             integer_lstm_param->recurrent_to_output_effective_bias.get(),
2373             integer_lstm_param->input_to_input_effective_bias.get(),
2374             integer_lstm_param->recurrent_to_input_effective_bias.get(),
2375             integer_lstm_param->projection_effective_bias.get(), /*n_batch=*/1,
2376             n_cell, n_input, n_output, output_state_ptr, output_state_zp,
2377             cell_state_ptr, output_ptr, GetTensorData<int16_t>(scratch0),
2378             GetTensorData<int16_t>(scratch1), GetTensorData<int16_t>(scratch2),
2379             GetTensorData<int16_t>(scratch3), GetTensorData<int8_t>(scratch4),
2380             GetTensorData<int32_t>(scratch5), context);
2381       }
2382     }
2383   }
2384 
2385   return kTfLiteOk;
2386 }
2387 
2388 TfLiteStatus EvalInteger8x8_8(
2389     const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
2390     const TfLiteTensor* input_to_forget_weights,
2391     const TfLiteTensor* input_to_cell_weights,
2392     const TfLiteTensor* input_to_output_weights,
2393     const TfLiteTensor* recurrent_to_input_weights,
2394     const TfLiteTensor* recurrent_to_forget_weights,
2395     const TfLiteTensor* recurrent_to_cell_weights,
2396     const TfLiteTensor* recurrent_to_output_weights,
2397     const TfLiteTensor* cell_to_input_weights,
2398     const TfLiteTensor* cell_to_forget_weights,
2399     const TfLiteTensor* cell_to_output_weights,
2400     const TfLiteTensor* input_layer_norm_coefficients,
2401     const TfLiteTensor* forget_layer_norm_coefficients,
2402     const TfLiteTensor* cell_layer_norm_coefficients,
2403     const TfLiteTensor* output_layer_norm_coefficients,
2404     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
2405     const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
2406     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
2407     const TfLiteLSTMParams* params, TfLiteTensor* output_state,
2408     TfLiteTensor* cell_state, TfLiteTensor* output,
2409     const lstm_eval::IntegerLstmParameter* integer_lstm_param,
2410     TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
2411     TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
2412     TfLiteTensor* scratch6, TfLiteTensor* scratch7) {
2413   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
2414   const int n_input = input->dims->data[input->dims->size - 1];
2415   int max_time, n_batch;
2416   if (input->dims->size == 2) {
2417     max_time = 1;
2418     n_batch = input->dims->data[0];
2419   } else {
2420     max_time = input->dims->data[0];
2421     n_batch = input->dims->data[1];
2422   }
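  // This variant assumes a time-major input of shape [max_time, n_batch,
  // n_input] and always steps forward through the sequence; it takes no
  // time_major or forward_sequence arguments.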
2423 
2424   // n_cell and n_output will be the same size when there is no projection.
2425   const int n_cell = input_to_output_weights->dims->data[0];
2426   const int n_output = recurrent_to_output_weights->dims->data[1];
2427 
2428   const int32_t input_zp = input->params.zero_point;
2429   const int32_t output_state_zp = output_state->params.zero_point;
2430 
2431   // Get params for time/batch/sequence.
2432   const int output_batch_leading_dim =
2433       output->dims->data[output->dims->size - 1];
2434   const int input_step = n_batch * n_input;
2435   const int output_step = n_batch * output_batch_leading_dim;
2436 
2437   for (int t = 0; t < max_time; t++) {
2438     const int t_rel = t;
2439     int8_t* output_ptr = GetTensorData<int8_t>(output) + t_rel * output_step;
2440     // Input can be int8 asymmetric or int16 symmetric.
2441     const int8_t* input_ptr = GetTensorData<int8_t>(input) + t_rel * input_step;
2442     lstm_eval::LstmStepInteger8x8_8(
2443         input_ptr, input_zp,
2444 
2445         GetTensorData<int8_t>(input_to_input_weights),
2446         integer_lstm_param->effective_input_to_input_scale_a,
2447         integer_lstm_param->effective_input_to_input_scale_b,
2448 
2449         GetTensorData<int8_t>(input_to_forget_weights),
2450         integer_lstm_param->effective_input_to_forget_scale_a,
2451         integer_lstm_param->effective_input_to_forget_scale_b,
2452 
2453         GetTensorData<int8_t>(input_to_cell_weights),
2454         integer_lstm_param->effective_input_to_cell_scale_a,
2455         integer_lstm_param->effective_input_to_cell_scale_b,
2456 
2457         GetTensorData<int8_t>(input_to_output_weights),
2458         integer_lstm_param->effective_input_to_output_scale_a,
2459         integer_lstm_param->effective_input_to_output_scale_b,
2460 
2461         GetTensorData<int8_t>(recurrent_to_input_weights),
2462         integer_lstm_param->effective_recurrent_to_input_scale_a,
2463         integer_lstm_param->effective_recurrent_to_input_scale_b,
2464 
2465         GetTensorData<int8_t>(recurrent_to_forget_weights),
2466         integer_lstm_param->effective_recurrent_to_forget_scale_a,
2467         integer_lstm_param->effective_recurrent_to_forget_scale_b,
2468 
2469         GetTensorData<int8_t>(recurrent_to_cell_weights),
2470         integer_lstm_param->effective_recurrent_to_cell_scale_a,
2471         integer_lstm_param->effective_recurrent_to_cell_scale_b,
2472 
2473         GetTensorData<int8_t>(recurrent_to_output_weights),
2474         integer_lstm_param->effective_recurrent_to_output_scale_a,
2475         integer_lstm_param->effective_recurrent_to_output_scale_b,
2476 
2477         GetTensorData<int8_t>(cell_to_input_weights),
2478         integer_lstm_param->effective_cell_to_input_scale_a,
2479         integer_lstm_param->effective_cell_to_input_scale_b,
2480 
2481         GetTensorData<int8_t>(cell_to_forget_weights),
2482         integer_lstm_param->effective_cell_to_forget_scale_a,
2483         integer_lstm_param->effective_cell_to_forget_scale_b,
2484 
2485         GetTensorData<int8_t>(cell_to_output_weights),
2486         integer_lstm_param->effective_cell_to_output_scale_a,
2487         integer_lstm_param->effective_cell_to_output_scale_b,
2488 
2489         GetTensorData<int8_t>(projection_weights),
2490         integer_lstm_param->effective_proj_scale_a,
2491         integer_lstm_param->effective_proj_scale_b,
2492 
2493         GetTensorData<int16_t>(input_layer_norm_coefficients),
2494         integer_lstm_param->layer_norm_input_scale_a,
2495         integer_lstm_param->layer_norm_input_scale_b,
2496 
2497         GetTensorData<int16_t>(forget_layer_norm_coefficients),
2498         integer_lstm_param->layer_norm_forget_scale_a,
2499         integer_lstm_param->layer_norm_forget_scale_b,
2500 
2501         GetTensorData<int16_t>(cell_layer_norm_coefficients),
2502         integer_lstm_param->layer_norm_cell_scale_a,
2503         integer_lstm_param->layer_norm_cell_scale_b,
2504 
2505         GetTensorData<int16_t>(output_layer_norm_coefficients),
2506         integer_lstm_param->layer_norm_output_scale_a,
2507         integer_lstm_param->layer_norm_output_scale_b,
2508 
2509         GetTensorData<int32_t>(input_gate_bias),
2510         GetTensorData<int32_t>(forget_gate_bias),
2511         GetTensorData<int32_t>(cell_gate_bias),
2512         GetTensorData<int32_t>(output_gate_bias),
2513         GetTensorData<int32_t>(projection_bias),
2514 
2515         params, integer_lstm_param->intermediate_scale_a,
2516         integer_lstm_param->intermediate_scale_b,
2517         integer_lstm_param->intermediate_zp,
2518         integer_lstm_param->quantized_cell_clip,
2519         integer_lstm_param->quantized_proj_clip, n_batch, n_cell, n_input,
2520         n_output, output_batch_leading_dim, GetTensorData<int8_t>(output_state),
2521         output_state_zp, GetTensorData<int16_t>(cell_state), output_ptr,
2522         GetTensorData<int8_t>(scratch0), GetTensorData<int8_t>(scratch1),
2523         GetTensorData<int16_t>(scratch2), GetTensorData<int16_t>(scratch3),
2524         GetTensorData<int16_t>(scratch4), GetTensorData<int16_t>(scratch5),
2525         GetTensorData<int16_t>(scratch6), GetTensorData<int16_t>(scratch7));
2526   }
2527 
2528   return kTfLiteOk;
2529 }
2530 
2531 }  // namespace lstm_eval
2532 }  // namespace builtin
2533 }  // namespace ops
2534 }  // namespace tflite
2535