1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/lstm_eval.h"
16
17 #include <cstdint>
18
19 #include "tensorflow/lite/c/c_api_internal.h"
20 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
21 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
22 #include "tensorflow/lite/kernels/op_macros.h"
23
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace lstm_eval {
28
29 namespace {
30
31 // Small float added to the variance to avoid divergence (division by a
32 // near-zero standard deviation) when normalizing for layer norm lstm.
33 const float kLayerNormEpsilon = 1e-8;
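// For reference, the normalization that MeanStddevNormalization is assumed to
// perform on each batch row x of length n_cell is:
//   x_norm[i] = (x[i] - mean(x)) / sqrt(variance(x) + kLayerNormEpsilon)
// The normalized values are then scaled by the per-gate layer norm
// coefficients and shifted by the gate bias in the layer-norm branches below.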
34
35 // Performs an LSTM batch inference step for input specified by input_ptr_batch.
36 // The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
37 // biases (*_bias_ptr), and buffers (*_scratch), along with additional
38 // parameters:
39 // - params: various LSTM params including activation, clipping, etc.,
40 // - n_batch: size of batch,
41 // - n_cell: number of cells (or units),
42 // - n_input: the input size,
43 // - n_aux_input: the auxiliary input size.
44 // - n_output: the output size.
45 // - output_batch_leading_dim: the leading dimension of the output buffer.
46 //
47 // LSTM weights:
48 // Input weights of size 'n_cell * n_input':
49 // input_to_input_weights - optional (can be nullptr)
50 // input_to_forget_weights
51 // input_to_cell_weights
52 // input_to_output_weights
53 // Auxiliary input weights of size 'n_cell * n_aux_input':
54 // aux_input_to_input_weights - optional
55 // aux_input_to_forget_weights - optional
56 // aux_input_to_cell_weights - optional
57 // aux_input_to_output_weights - optional
58 // Recurrent weights of size 'n_cell * n_output':
59 // recurrent_to_input_weights - optional
60 // recurrent_to_forget_weights
61 // recurrent_to_cell_weights
62 // recurrent_to_output_weights
63 // Peephole weights of size 'n_cell', representing diagonal matrices.
64 // cell_to_input_weights - optional
65 // cell_to_forget_weights - optional
66 // cell_to_output_weights - optional
67 // Projection weights of size 'n_output * n_cell'
68 // projection_weights_ptr - optional
69 // Gate biases of size 'n_cell':
70 // input_gate_bias_ptr - optional
71 // forget_gate_bias_ptr
72 // cell_bias_ptr
73 // output_gate_bias_ptr
74 //
75 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
76 // input_layer_norm_coefficients_ptr - optional
77 // forget_layer_norm_coefficients_ptr - optional
78 // cell_layer_norm_coefficients_ptr - optional
79 // output_layer_norm_coefficients_ptr - optional
80 //
81 // The pointers to the cell and output state and the output are updated.
82 //
83 // The pointers with the suffix "_batch" point to data aligned in batch_major
84 // order, and each step processes batch_size many inputs from input_ptr_batch,
85 // and updates batch_size many cell and output states.
86 //
87 // The output_batch_leading_dim is output.shape[-1], i.e. the innermost (last)
88 // dimension of the output tensor, and in most cases it is equal to n_output. It
89 // is not equal when we want to store the LSTM output into a slice of the output
90 // tensor, e.g. for bidirectional LSTMs with merge_outputs. In that case the
91 // batched operations cannot be used, since they assume that the batched outputs
92 // are contiguous, and we manually loop over the batched outputs instead.
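//
// For reference, the per-gate math implemented below is the standard LSTM cell
// update (omitting the optional aux input, peephole and layer norm terms):
//   i = sigmoid(W_xi * x + W_hi * h_prev + b_i)   (skipped with CIFG)
//   f = sigmoid(W_xf * x + W_hf * h_prev + b_f)
//   g = act(W_xg * x + W_hg * h_prev + b_g)
//   o = sigmoid(W_xo * x + W_ho * h_prev + b_o)
//   c = f .* c_prev + i .* g   (with CIFG: c = f .* c_prev + (1 - f) .* g)
//   h = o .* act(c), optionally followed by projection and clipping.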
93 inline void LstmStepWithAuxInput(
94 const float* input_ptr_batch, const float* input_to_input_weights_ptr,
95 const float* input_to_forget_weights_ptr,
96 const float* input_to_cell_weights_ptr,
97 const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
98 const float* aux_input_to_input_weights_ptr,
99 const float* aux_input_to_forget_weights_ptr,
100 const float* aux_input_to_cell_weights_ptr,
101 const float* aux_input_to_output_weights_ptr,
102 const float* recurrent_to_input_weights_ptr,
103 const float* recurrent_to_forget_weights_ptr,
104 const float* recurrent_to_cell_weights_ptr,
105 const float* recurrent_to_output_weights_ptr,
106 const float* cell_to_input_weights_ptr,
107 const float* cell_to_forget_weights_ptr,
108 const float* cell_to_output_weights_ptr,
109 const float* input_layer_norm_coefficients_ptr,
110 const float* forget_layer_norm_coefficients_ptr,
111 const float* cell_layer_norm_coefficients_ptr,
112 const float* output_layer_norm_coefficients_ptr,
113 const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
114 const float* cell_bias_ptr, const float* output_gate_bias_ptr,
115 const float* projection_weights_ptr, const float* projection_bias_ptr,
116 const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
117 int n_aux_input, int n_output, int output_batch_leading_dim,
118 float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
119 float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
120 float* output_ptr_batch) {
121 // Since we have already checked that weights are all there or none, we can
122 // check the existence of only one to get the condition.
123 const bool use_cifg = (input_to_input_weights_ptr == nullptr);
124 const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
125 const bool is_layer_norm_lstm =
126 (forget_layer_norm_coefficients_ptr != nullptr);
127
128 // Initialize scratch buffers with bias for regular lstm or initialize with
129 // zero for layer norm lstm.
130 if (is_layer_norm_lstm) {
131 if (!use_cifg) {
132 tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
133 }
134 tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
135 tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
136 tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
137 } else {
138 if (!use_cifg) {
139 tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
140 n_batch, input_gate_scratch);
141 }
142 tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
143 forget_gate_scratch);
144 tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
145 cell_scratch);
146 tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
147 output_gate_scratch);
148 }
149
150 // For each batch and cell: compute input_weight * input.
151 if (!use_cifg) {
152 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
153 input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
154 input_gate_scratch, /*result_stride=*/1);
155 }
156
157 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
158 input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
159 forget_gate_scratch, /*result_stride=*/1);
160 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
161 input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
162 cell_scratch, /*result_stride=*/1);
163 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
164 input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
165 output_gate_scratch, /*result_stride=*/1);
166
167 // If auxiliary input is available, compute aux_input_weight * aux_input.
168 if (aux_input_ptr_batch != nullptr) {
169 if (!use_cifg) {
170 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
171 aux_input_to_input_weights_ptr, n_cell, n_aux_input,
172 aux_input_ptr_batch, n_batch, input_gate_scratch,
173 /*result_stride=*/1);
174 }
175
176 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
177 aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
178 aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1);
179 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
180 aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch,
181 n_batch, cell_scratch, /*result_stride=*/1);
182 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
183 aux_input_to_output_weights_ptr, n_cell, n_aux_input,
184 aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1);
185 }
186
187 // For each batch and cell: compute recurrent_weight * output_state.
188 if (!use_cifg) {
189 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
190 recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
191 n_batch, input_gate_scratch, /*result_stride=*/1);
192 }
193 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
194 recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
195 n_batch, forget_gate_scratch,
196 /*result_stride=*/1);
197 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
198 recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
199 n_batch, cell_scratch, /*result_stride=*/1);
200 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
201 recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
202 n_batch, output_gate_scratch,
203 /*result_stride=*/1);
204
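// At this point each *_gate_scratch holds its gate pre-activation: the gate
// bias (or zero for layer norm lstm) plus the accumulated input,
// auxiliary-input and recurrent contributions. The per-gate updates below add
// the optional peephole and layer norm terms and apply the nonlinearities.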
205 // For each batch and cell: update input gate.
206 if (!use_cifg) {
207 if (use_peephole) {
208 tensor_utils::VectorBatchVectorCwiseProductAccumulate(
209 cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
210 input_gate_scratch);
211 }
212 if (is_layer_norm_lstm) {
213 tensor_utils::MeanStddevNormalization(input_gate_scratch,
214 input_gate_scratch, n_cell, n_batch,
215 kLayerNormEpsilon);
216 tensor_utils::VectorBatchVectorCwiseProduct(
217 input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
218 n_batch, input_gate_scratch);
219 tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
220 input_gate_scratch);
221 }
222 tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
223 input_gate_scratch);
224 }
225
226 // For each batch and cell: update forget gate.
227 if (use_peephole) {
228 tensor_utils::VectorBatchVectorCwiseProductAccumulate(
229 cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
230 forget_gate_scratch);
231 }
232 if (is_layer_norm_lstm) {
233 tensor_utils::MeanStddevNormalization(forget_gate_scratch,
234 forget_gate_scratch, n_cell, n_batch,
235 kLayerNormEpsilon);
236 tensor_utils::VectorBatchVectorCwiseProduct(
237 forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
238 n_batch, forget_gate_scratch);
239 tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
240 forget_gate_scratch);
241 }
242 tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
243 forget_gate_scratch);
244
245 // For each batch and cell: update the cell.
246 tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
247 n_batch * n_cell, cell_state_ptr);
248 if (is_layer_norm_lstm) {
249 tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
250 n_batch, kLayerNormEpsilon);
251 tensor_utils::VectorBatchVectorCwiseProduct(
252 cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
253 cell_scratch);
254 tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
255 cell_scratch);
256 }
257 tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
258 params->activation, cell_scratch);
259 if (use_cifg) {
260 tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
261 forget_gate_scratch);
262 tensor_utils::VectorVectorCwiseProductAccumulate(
263 cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
264 } else {
265 tensor_utils::VectorVectorCwiseProductAccumulate(
266 cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
267 }
268 if (params->cell_clip > 0.0) {
269 tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
270 params->cell_clip, cell_state_ptr);
271 }
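// cell_state now holds the updated cell state
//   c_t = f .* c_prev + i .* g   (with CIFG the input gate is replaced by 1 - f),
// clipped when params->cell_clip > 0.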
272
273 // For each batch and cell: update the output gate.
274 if (use_peephole) {
275 tensor_utils::VectorBatchVectorCwiseProductAccumulate(
276 cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
277 output_gate_scratch);
278 }
279 if (is_layer_norm_lstm) {
280 tensor_utils::MeanStddevNormalization(output_gate_scratch,
281 output_gate_scratch, n_cell, n_batch,
282 kLayerNormEpsilon);
283 tensor_utils::VectorBatchVectorCwiseProduct(
284 output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
285 n_batch, output_gate_scratch);
286 tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
287 output_gate_scratch);
288 }
289 tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
290 output_gate_scratch);
291 tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
292 params->activation, cell_scratch);
293 tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
294 n_batch * n_cell, output_gate_scratch);
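// output_gate_scratch now holds the gated output o .* act(c_t), which becomes
// the new output and output_state below (optionally after projection).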
295
296 const bool use_projection_weight = (projection_weights_ptr != nullptr);
297 const bool use_projection_bias = (projection_bias_ptr != nullptr);
298
299 // For each batch: update the projection and output_state. Note that since
300 // the output batch rows may not be contiguous (output_batch_leading_dim !=
301 // n_output), we unroll the batched operations where this is the case.
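// When projection weights are present the emitted output is
//   output = W_proj * (o .* act(c_t)) + b_proj   (b_proj taken as zero when
// absent), clipped when params->proj_clip > 0; otherwise the gated output is
// copied through unchanged.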
302 if (output_batch_leading_dim == n_output) {
303 if (use_projection_weight) {
304 if (use_projection_bias) {
305 tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
306 n_batch, output_ptr_batch);
307 } else {
308 tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
309 }
310 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
311 projection_weights_ptr, n_output, n_cell, output_gate_scratch,
312 n_batch, output_ptr_batch, /*result_stride=*/1);
313 if (params->proj_clip > 0.0) {
314 tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
315 params->proj_clip, output_ptr_batch);
316 }
317 } else {
318 tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
319 output_ptr_batch);
320 }
321 tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
322 output_state_ptr);
323 } else {
324 if (use_projection_weight) {
325 if (use_projection_bias) {
326 for (int k = 0; k < n_batch; k++) {
327 tensor_utils::CopyVector(
328 projection_bias_ptr, n_output,
329 output_ptr_batch + k * output_batch_leading_dim);
330 }
331 } else {
332 for (int k = 0; k < n_batch; k++) {
333 tensor_utils::ZeroVector(
334 output_ptr_batch + k * output_batch_leading_dim, n_output);
335 }
336 }
337 for (int k = 0; k < n_batch; k++) {
338 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
339 projection_weights_ptr, n_output, n_cell,
340 output_gate_scratch + k * n_cell,
341 /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim,
342 /*result_stride=*/1);
343 if (params->proj_clip > 0.0) {
344 tensor_utils::ClipVector(
345 output_ptr_batch + k * output_batch_leading_dim, n_output,
346 params->proj_clip,
347 output_ptr_batch + k * output_batch_leading_dim);
348 }
349 }
350 } else {
351 for (int k = 0; k < n_batch; k++) {
352 tensor_utils::CopyVector(
353 output_gate_scratch + k * n_output, n_output,
354 output_ptr_batch + k * output_batch_leading_dim);
355 }
356 }
357 for (int k = 0; k < n_batch; k++) {
358 tensor_utils::CopyVector(output_ptr_batch + k * output_batch_leading_dim,
359 n_output, output_state_ptr + k * n_output);
360 }
361 }
362 }
363
364 // Same as above but with quantized weight matrices. In detail:
365 // Input of size 'n_batch * n_input':
366 // input_ptr_batch
367 //
368 // LSTM weights:
369 // Quantized input weights of size 'n_cell * n_input':
370 // input_to_input_weights - optional (can be nullptr)
371 // input_to_forget_weights
372 // input_to_cell_weights
373 // input_to_output_weights
374 // Quantized auxiliary input weights of size 'n_cell * n_aux_input':
375 // aux_input_to_input_weights - optional
376 // aux_input_to_forget_weights - optional
377 // aux_input_to_cell_weights - optional
378 // aux_input_to_output_weights - optional
379 // Quantized recurrent weights of size 'n_cell * n_output':
380 // recurrent_to_input_weights - optional
381 // recurrent_to_forget_weights
382 // recurrent_to_cell_weights
383 // recurrent_to_output_weights
384 // Quantized peephole weights of size 'n_cell', representing diagonal matrices.
385 // cell_to_input_weights - optional
386 // cell_to_forget_weights - optional
387 // cell_to_output_weights - optional
388 // Quantized projection weights of size 'n_output * n_cell'
389 // projection_weights_ptr - optional
390 // Weight scales (scalars) for each of the weights above.
391 // input_to_input_weights_scale - optional
392 // input_to_forget_weights_scale
393 // input_to_cell_weights_scale
394 // input_to_output_weights_scale
395 // aux_input_to_input_weights_scale - optional
396 // aux_input_to_forget_weights_scale - optional
397 // aux_input_to_cell_weights_scale - optional
398 // aux_input_to_output_weights_scale - optional
399 // recurrent_to_input_weights_scale - optional
400 // recurrent_to_forget_weights_scale
401 // recurrent_to_cell_weights_scale
402 // recurrent_to_output_weights_scale
403 // cell_to_input_weights_scale - optional
404 // cell_to_forget_weights_scale
405 // cell_to_output_weights_scale
406 // projection_weights_scale - optional
407 // Gate biases of size 'n_cell':
408 // input_gate_bias_ptr - optional
409 // forget_gate_bias_ptr
410 // cell_bias_ptr
411 // output_gate_bias_ptr
412 //
413 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
414 // input_layer_norm_coefficients_ptr - optional
415 // forget_layer_norm_coefficients_ptr - optional
416 // cell_layer_norm_coefficients_ptr - optional
417 // output_layer_norm_coefficients_ptr - optional
418 //
419 // Temporary pre-allocated storage for quantized values:
420 // quantized_input_ptr_batch (same size as input_ptr_batch)
421 // quantized_output_state_ptr (same size as output_state_ptr)
422 // quantized_cell_state_ptr (same size as cell_state_ptr)
423 // Temporary pre-allocated storage for recovered values:
424 // recovered_cell_weights (same size as cell_to_*_weights)
425 //
426 // Outputs:
427 // output_state_ptr - size 'n_batch * n_output'
428 // cell_state_ptr - size 'n_batch * n_cell'
429 // output_ptr_batch - size 'n_batch * output_batch_leading_dim'
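//
// A sketch of the hybrid arithmetic, assuming SymmetricQuantizeFloats returns
// int8 values q and a per-batch scale s such that x ~ s * q: each matmul
// accumulates sum_j(w_q[i][j] * x_q[b][j]) in integer arithmetic and scales the
// result by product_scaling_factors[b] = s[b] * weight_scale to recover an
// approximate float contribution, accumulated into the float gate scratches.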
430 inline void LstmStepWithAuxInput(
431 const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
432 float input_to_input_weights_scale,
433 const int8_t* input_to_forget_weights_ptr,
434 float input_to_forget_weights_scale,
435 const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
436 const int8_t* input_to_output_weights_ptr,
437 float input_to_output_weights_scale, const float* aux_input_ptr_batch,
438 const int8_t* aux_input_to_input_weights_ptr,
439 float aux_input_to_input_weights_scale,
440 const int8_t* aux_input_to_forget_weights_ptr,
441 float aux_input_to_forget_weights_scale,
442 const int8_t* aux_input_to_cell_weights_ptr,
443 float aux_input_to_cell_weights_scale,
444 const int8_t* aux_input_to_output_weights_ptr,
445 float aux_input_to_output_weights_scale,
446 const int8_t* recurrent_to_input_weights_ptr,
447 float recurrent_to_input_weights_scale,
448 const int8_t* recurrent_to_forget_weights_ptr,
449 float recurrent_to_forget_weights_scale,
450 const int8_t* recurrent_to_cell_weights_ptr,
451 float recurrent_to_cell_weights_scale,
452 const int8_t* recurrent_to_output_weights_ptr,
453 float recurrent_to_output_weights_scale,
454 const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
455 const int8_t* cell_to_forget_weights_ptr,
456 float cell_to_forget_weights_scale,
457 const int8_t* cell_to_output_weights_ptr,
458 float cell_to_output_weights_scale,
459 const float* input_layer_norm_coefficients_ptr,
460 const float* forget_layer_norm_coefficients_ptr,
461 const float* cell_layer_norm_coefficients_ptr,
462 const float* output_layer_norm_coefficients_ptr,
463 const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
464 const float* cell_bias_ptr, const float* output_gate_bias_ptr,
465 const int8_t* projection_weights_ptr, float projection_weights_scale,
466 const float* projection_bias_ptr, const TfLiteLSTMParams* params,
467 int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
468 int output_batch_leading_dim, float* input_gate_scratch,
469 float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
470 float* scaling_factors, float* product_scaling_factors,
471 float* recovered_cell_weights, int8_t* quantized_input_ptr_batch,
472 int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
473 int8_t* quantized_cell_state_ptr, float* output_state_ptr,
474 float* cell_state_ptr, float* output_ptr_batch) {
475 // Since we have already checked that weights are all there or none, we
476 // can check the existence of only one to get the condition.
477 const bool use_cifg = (input_to_input_weights_ptr == nullptr);
478 const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
479 const bool is_layer_norm_lstm =
480 (forget_layer_norm_coefficients_ptr != nullptr);
481
482 // Initialize scratch buffers with bias (regular lstm) or zero (layer norm lstm).
483 if (is_layer_norm_lstm) {
484 if (!use_cifg) {
485 tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
486 }
487 tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
488 tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
489 tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
490 } else {
491 if (!use_cifg) {
492 tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
493 n_batch, input_gate_scratch);
494 }
495 tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
496 forget_gate_scratch);
497 tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
498 cell_scratch);
499 tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
500 output_gate_scratch);
501 }
502
503 if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
504 // Save quantization and matmul computation for all zero input.
505 float unused_min, unused_max;
506 for (int b = 0; b < n_batch; ++b) {
507 const int offset = b * n_input;
508 tensor_utils::SymmetricQuantizeFloats(
509 input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
510 &unused_min, &unused_max, &scaling_factors[b]);
511 }
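// scaling_factors[b] now holds the per-batch quantization scale, i.e.
// (assuming symmetric quantization) input ~ scaling_factors[b] * quantized_input.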
512 // For each batch and cell: compute input_weight * input.
513 if (!use_cifg) {
514 for (int b = 0; b < n_batch; ++b) {
515 product_scaling_factors[b] =
516 scaling_factors[b] * input_to_input_weights_scale;
517 }
518 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
519 input_to_input_weights_ptr, n_cell, n_input,
520 quantized_input_ptr_batch, product_scaling_factors, n_batch,
521 input_gate_scratch, /*result_stride=*/1);
522 }
523
524 for (int b = 0; b < n_batch; ++b) {
525 product_scaling_factors[b] =
526 scaling_factors[b] * input_to_forget_weights_scale;
527 }
528 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
529 input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
530 product_scaling_factors, n_batch, forget_gate_scratch,
531 /*result_stride=*/1);
532
533 for (int b = 0; b < n_batch; ++b) {
534 product_scaling_factors[b] =
535 scaling_factors[b] * input_to_cell_weights_scale;
536 }
537 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
538 input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
539 product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);
540
541 for (int b = 0; b < n_batch; ++b) {
542 product_scaling_factors[b] =
543 scaling_factors[b] * input_to_output_weights_scale;
544 }
545 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
546 input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
547 product_scaling_factors, n_batch, output_gate_scratch,
548 /*result_stride=*/1);
549 }
550
551 if (aux_input_ptr_batch != nullptr &&
552 !tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_aux_input)) {
553 // Save quantization and matmul computation for all zero input.
554 float unused_min, unused_max;
555 for (int b = 0; b < n_batch; ++b) {
556 const int offset = b * n_aux_input;
557 tensor_utils::SymmetricQuantizeFloats(
558 aux_input_ptr_batch + offset, n_aux_input,
559 quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max,
560 &scaling_factors[b]);
561 }
562 // For each batch and cell: compute aux_input_weight * aux_input.
563 if (!use_cifg) {
564 for (int b = 0; b < n_batch; ++b) {
565 product_scaling_factors[b] =
566 scaling_factors[b] * aux_input_to_input_weights_scale;
567 }
568 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
569 aux_input_to_input_weights_ptr, n_cell, n_aux_input,
570 quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
571 input_gate_scratch, /*result_stride=*/1);
572 }
573
574 for (int b = 0; b < n_batch; ++b) {
575 product_scaling_factors[b] =
576 scaling_factors[b] * aux_input_to_forget_weights_scale;
577 }
578 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
579 aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
580 quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
581 forget_gate_scratch, /*result_stride=*/1);
582
583 for (int b = 0; b < n_batch; ++b) {
584 product_scaling_factors[b] =
585 scaling_factors[b] * aux_input_to_cell_weights_scale;
586 }
587 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
588 aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
589 quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
590 cell_scratch, /*result_stride=*/1);
591
592 for (int b = 0; b < n_batch; ++b) {
593 product_scaling_factors[b] =
594 scaling_factors[b] * aux_input_to_output_weights_scale;
595 }
596 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
597 aux_input_to_output_weights_ptr, n_cell, n_aux_input,
598 quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
599 output_gate_scratch, /*result_stride=*/1);
600 }
601
602 if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
603 // Save quantization and matmul computation for all zero input.
604 float unused_min, unused_max;
605 for (int b = 0; b < n_batch; ++b) {
606 const int offset = b * n_output;
607 tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
608 quantized_output_state_ptr + offset,
609 &unused_min, &unused_max,
610 &scaling_factors[b]);
611 }
612 // For each batch and cell: compute recurrent_weight * output_state.
613 if (!use_cifg) {
614 for (int b = 0; b < n_batch; ++b) {
615 product_scaling_factors[b] =
616 scaling_factors[b] * recurrent_to_input_weights_scale;
617 }
618 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
619 recurrent_to_input_weights_ptr, n_cell, n_output,
620 quantized_output_state_ptr, product_scaling_factors, n_batch,
621 input_gate_scratch, /*result_stride=*/1);
622 }
623
624 for (int b = 0; b < n_batch; ++b) {
625 product_scaling_factors[b] =
626 scaling_factors[b] * recurrent_to_forget_weights_scale;
627 }
628 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
629 recurrent_to_forget_weights_ptr, n_cell, n_output,
630 quantized_output_state_ptr, product_scaling_factors, n_batch,
631 forget_gate_scratch, /*result_stride=*/1);
632
633 for (int b = 0; b < n_batch; ++b) {
634 product_scaling_factors[b] =
635 scaling_factors[b] * recurrent_to_cell_weights_scale;
636 }
637 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
638 recurrent_to_cell_weights_ptr, n_cell, n_output,
639 quantized_output_state_ptr, product_scaling_factors, n_batch,
640 cell_scratch, /*result_stride=*/1);
641
642 for (int b = 0; b < n_batch; ++b) {
643 product_scaling_factors[b] =
644 scaling_factors[b] * recurrent_to_output_weights_scale;
645 }
646 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
647 recurrent_to_output_weights_ptr, n_cell, n_output,
648 quantized_output_state_ptr, product_scaling_factors, n_batch,
649 output_gate_scratch, /*result_stride=*/1);
650 }
651
652 // Skip the peephole contributions below when the cell state is all zeros.
653 bool is_cell_state_all_zeros =
654 tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
655
656 // For each batch and cell: update input gate.
657 if (!use_cifg) {
658 if (use_peephole && !is_cell_state_all_zeros) {
659 tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
660 cell_to_input_weights_scale,
661 recovered_cell_weights);
662 tensor_utils::VectorBatchVectorCwiseProductAccumulate(
663 recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
664 input_gate_scratch);
665 }
666 if (is_layer_norm_lstm) {
667 tensor_utils::MeanStddevNormalization(input_gate_scratch,
668 input_gate_scratch, n_cell, n_batch,
669 kLayerNormEpsilon);
670 tensor_utils::VectorBatchVectorCwiseProduct(
671 input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
672 n_batch, input_gate_scratch);
673 tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
674 input_gate_scratch);
675 }
676 tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
677 input_gate_scratch);
678 }
679
680 // For each batch and cell: update forget gate.
681 if (use_peephole && !is_cell_state_all_zeros) {
682 tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
683 cell_to_forget_weights_scale,
684 recovered_cell_weights);
685 tensor_utils::VectorBatchVectorCwiseProductAccumulate(
686 recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
687 forget_gate_scratch);
688 }
689 if (is_layer_norm_lstm) {
690 tensor_utils::MeanStddevNormalization(forget_gate_scratch,
691 forget_gate_scratch, n_cell, n_batch,
692 kLayerNormEpsilon);
693 tensor_utils::VectorBatchVectorCwiseProduct(
694 forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
695 n_batch, forget_gate_scratch);
696 tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
697 forget_gate_scratch);
698 }
699 tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
700 forget_gate_scratch);
701
702 // For each batch and cell: update the cell.
703 tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
704 n_batch * n_cell, cell_state_ptr);
705 if (is_layer_norm_lstm) {
706 tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
707 n_batch, kLayerNormEpsilon);
708 tensor_utils::VectorBatchVectorCwiseProduct(
709 cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
710 cell_scratch);
711 tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
712 cell_scratch);
713 }
714 tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
715 params->activation, cell_scratch);
716 if (use_cifg) {
717 tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
718 forget_gate_scratch);
719 tensor_utils::VectorVectorCwiseProductAccumulate(
720 cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
721 } else {
722 tensor_utils::VectorVectorCwiseProductAccumulate(
723 cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
724 }
725 if (params->cell_clip > 0.0) {
726 tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
727 params->cell_clip, cell_state_ptr);
728 }
729
730 is_cell_state_all_zeros =
731 tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
732 // For each batch and cell: update the output gate.
733 if (use_peephole && !is_cell_state_all_zeros) {
734 tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
735 cell_to_output_weights_scale,
736 recovered_cell_weights);
737 tensor_utils::VectorBatchVectorCwiseProductAccumulate(
738 recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
739 output_gate_scratch);
740 }
741 if (is_layer_norm_lstm) {
742 tensor_utils::MeanStddevNormalization(output_gate_scratch,
743 output_gate_scratch, n_cell, n_batch,
744 kLayerNormEpsilon);
745 tensor_utils::VectorBatchVectorCwiseProduct(
746 output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
747 n_batch, output_gate_scratch);
748 tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
749 output_gate_scratch);
750 }
751 tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
752 output_gate_scratch);
753 tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
754 params->activation, cell_scratch);
755 tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
756 n_batch * n_cell, output_gate_scratch);
757
758 const bool use_projection_weight = (projection_weights_ptr != nullptr);
759 const bool use_projection_bias = (projection_bias_ptr != nullptr);
760
761 // For each batch: update the projection and output_state. Note that since
762 // the output batch rows may not be contiguous (output_batch_leading_dim !=
763 // n_output), we unroll the batched operations where this is the case.
764 if (output_batch_leading_dim == n_output) {
765 if (use_projection_weight) {
766 if (use_projection_bias) {
767 tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
768 n_batch, output_ptr_batch);
769 } else {
770 tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
771 }
772 if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
773 // Save quantization and matmul computation for all zero input.
774 float unused_min, unused_max;
775 for (int b = 0; b < n_batch; ++b) {
776 const int offset = b * n_cell;
777 tensor_utils::SymmetricQuantizeFloats(
778 output_gate_scratch + offset, n_cell,
779 quantized_cell_state_ptr + offset, &unused_min, &unused_max,
780 &scaling_factors[b]);
781 }
782 for (int b = 0; b < n_batch; ++b) {
783 product_scaling_factors[b] =
784 scaling_factors[b] * projection_weights_scale;
785 }
786 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
787 projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
788 product_scaling_factors, n_batch, output_ptr_batch,
789 /*result_stride=*/1);
790 }
791 if (params->proj_clip > 0.0) {
792 tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
793 params->proj_clip, output_ptr_batch);
794 }
795 } else {
796 tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
797 output_ptr_batch);
798 }
799 tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
800 output_state_ptr);
801 } else {
802 if (use_projection_weight) {
803 if (use_projection_bias) {
804 for (int k = 0; k < n_batch; k++) {
805 tensor_utils::CopyVector(
806 projection_bias_ptr, n_output,
807 output_ptr_batch + k * output_batch_leading_dim);
808 }
809 } else {
810 for (int k = 0; k < n_batch; k++) {
811 tensor_utils::ZeroVector(
812 output_ptr_batch + k * output_batch_leading_dim, n_output);
813 }
814 }
815 if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
816 // Save quantization and matmul computation for all zero input.
817 float unused_min, unused_max;
818 for (int b = 0; b < n_batch; ++b) {
819 const int offset = b * n_cell;
820 tensor_utils::SymmetricQuantizeFloats(
821 output_gate_scratch + offset, n_cell,
822 quantized_cell_state_ptr + offset, &unused_min, &unused_max,
823 &scaling_factors[b]);
824 }
825 for (int b = 0; b < n_batch; ++b) {
826 product_scaling_factors[b] =
827 scaling_factors[b] * projection_weights_scale;
828 }
829 for (int k = 0; k < n_batch; k++) {
830 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
831 projection_weights_ptr, n_output, n_cell,
832 quantized_cell_state_ptr + k * n_cell,
833 &product_scaling_factors[k],
834 /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim,
835 /*result_stride=*/1);
836 }
837 }
838 if (params->proj_clip > 0.0) {
839 for (int k = 0; k < n_batch; k++) {
840 tensor_utils::ClipVector(
841 output_ptr_batch + k * output_batch_leading_dim, n_output,
842 params->proj_clip,
843 output_ptr_batch + k * output_batch_leading_dim);
844 }
845 }
846 } else {
847 for (int k = 0; k < n_batch; k++) {
848 tensor_utils::CopyVector(
849 output_gate_scratch + k * n_output, n_output,
850 output_ptr_batch + k * output_batch_leading_dim);
851 }
852 }
853 for (int k = 0; k < n_batch; k++) {
854 tensor_utils::CopyVector(output_ptr_batch + k * output_batch_leading_dim,
855 n_output, output_state_ptr + k * n_output);
856 }
857 }
858 }
859
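// Returns the weight tensor's data as an int8 pointer. Hybrid kernels may
// receive either int8 or uint8 weights; for uint8 tensors the raw buffer is
// reinterpreted as int8 in place rather than converted.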
860 int8_t* GetInt8DataPtr(const TfLiteTensor* tensor, const bool is_uint8) {
861 if (is_uint8) {
862 return reinterpret_cast<int8_t*>(tensor->data.uint8);
863 } else {
864 return tensor->data.int8;
865 }
866 }
867
868 } // namespace
869
870 TfLiteStatus EvalFloat(
871 const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
872 const TfLiteTensor* input_to_forget_weights,
873 const TfLiteTensor* input_to_cell_weights,
874 const TfLiteTensor* input_to_output_weights,
875 const TfLiteTensor* recurrent_to_input_weights,
876 const TfLiteTensor* recurrent_to_forget_weights,
877 const TfLiteTensor* recurrent_to_cell_weights,
878 const TfLiteTensor* recurrent_to_output_weights,
879 const TfLiteTensor* cell_to_input_weights,
880 const TfLiteTensor* cell_to_forget_weights,
881 const TfLiteTensor* cell_to_output_weights,
882 const TfLiteTensor* input_layer_norm_coefficients,
883 const TfLiteTensor* forget_layer_norm_coefficients,
884 const TfLiteTensor* cell_layer_norm_coefficients,
885 const TfLiteTensor* output_layer_norm_coefficients,
886 const TfLiteTensor* aux_input,
887 const TfLiteTensor* aux_input_to_input_weights,
888 const TfLiteTensor* aux_input_to_forget_weights,
889 const TfLiteTensor* aux_input_to_cell_weights,
890 const TfLiteTensor* aux_input_to_output_weights,
891 const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
892 const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
893 const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
894 const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
895 int output_offset, TfLiteTensor* scratch_buffer,
896 TfLiteTensor* activation_state, TfLiteTensor* cell_state,
897 TfLiteTensor* output) {
898 TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
899 int max_time, n_batch;
900 if (input->dims->size == 3) {
901 max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
902 n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
903 } else {
904 max_time = 1;
905 n_batch = input->dims->data[0];
906 }
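// Note: a 3-D input is laid out as [max_time, n_batch, n_input] when
// time_major and as [n_batch, max_time, n_input] otherwise; a 2-D input is a
// single time step of shape [n_batch, n_input].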
907 const int n_input = input->dims->data[input->dims->size - 1];
908 const int aux_input_size =
909 (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
910
911 // n_cell and n_output will be the same size when there is no projection.
912 const int n_cell = input_to_output_weights->dims->data[0];
913 const int n_output = recurrent_to_output_weights->dims->data[1];
914
915 // Since we have already checked that weights are all there or none, we can
916 // check the existence of only one to get the condition.
917 const bool use_cifg = (input_to_input_weights == nullptr);
918 const bool use_peephole = (cell_to_output_weights != nullptr);
919 const bool is_layer_norm_lstm = (forget_layer_norm_coefficients != nullptr);
920
921 // Index the scratch buffers pointers to the global scratch buffer.
922 float* input_gate_scratch = nullptr;
923 float* cell_scratch = nullptr;
924 float* forget_gate_scratch = nullptr;
925 float* output_gate_scratch = nullptr;
926 if (use_cifg) {
927 cell_scratch = scratch_buffer->data.f;
928 forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
929 output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
930 } else {
931 input_gate_scratch = scratch_buffer->data.f;
932 cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
933 forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
934 output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
935 }
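// The scratch buffer is laid out as consecutive n_cell * n_batch segments:
// [input_gate (absent with CIFG), cell, forget, output].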
936
937 // Check optional tensors, the respective pointers can be null.
938 const float* input_to_input_weights_ptr =
939 (use_cifg) ? nullptr : input_to_input_weights->data.f;
940 const float* recurrent_to_input_weights_ptr =
941 (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
942 const float* input_gate_bias_ptr =
943 (use_cifg) ? nullptr : input_gate_bias->data.f;
944 const float* cell_to_input_weights_ptr =
945 (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
946 const float* cell_to_forget_weights_ptr =
947 (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
948 const float* cell_to_output_weights_ptr =
949 (use_peephole) ? cell_to_output_weights->data.f : nullptr;
950 const float* input_layer_norm_coefficients_ptr =
951 (is_layer_norm_lstm && !use_cifg) ? input_layer_norm_coefficients->data.f
952 : nullptr;
953 const float* forget_layer_norm_coefficients_ptr =
954 is_layer_norm_lstm ? forget_layer_norm_coefficients->data.f : nullptr;
955 const float* cell_layer_norm_coefficients_ptr =
956 is_layer_norm_lstm ? cell_layer_norm_coefficients->data.f : nullptr;
957 const float* output_layer_norm_coefficients_ptr =
958 is_layer_norm_lstm ? output_layer_norm_coefficients->data.f : nullptr;
959 const float* projection_weights_ptr =
960 (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
961 const float* projection_bias_ptr =
962 (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
963
964 float* aux_input_ptr = nullptr;
965 float* aux_input_to_input_weights_ptr = nullptr;
966 float* aux_input_to_forget_weights_ptr = nullptr;
967 float* aux_input_to_cell_weights_ptr = nullptr;
968 float* aux_input_to_output_weights_ptr = nullptr;
969 if (aux_input_size > 0) {
970 if (!use_cifg) {
971 aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f;
972 }
973 aux_input_to_forget_weights_ptr = aux_input_to_forget_weights->data.f;
974 aux_input_to_cell_weights_ptr = aux_input_to_cell_weights->data.f;
975 aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
976 }
977
978 const int output_batch_leading_dim =
979 output->dims->data[output->dims->size - 1];
980 if (time_major) {
981 // Loop through the sequence.
982 const int input_step = n_batch * n_input;
983 const int output_step = n_batch * output_batch_leading_dim;
984 for (int t = 0; t < max_time; t++) {
985 // If this is the forward_sequence, step forward, otherwise step
986 // backwards.
987 const int t_rel = forward_sequence ? t : max_time - t - 1;
988 const float* input_ptr_batch = input->data.f + t_rel * input_step;
989 if (aux_input) {
990 aux_input_ptr = aux_input->data.f + t_rel * n_batch * aux_input_size;
991 }
992 float* output_ptr_time =
993 output->data.f + t_rel * output_step + output_offset;
994
995 LstmStepWithAuxInput(
996 input_ptr_batch, input_to_input_weights_ptr,
997 input_to_forget_weights->data.f, input_to_cell_weights->data.f,
998 input_to_output_weights->data.f, aux_input_ptr,
999 aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
1000 aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
1001 recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
1002 recurrent_to_cell_weights->data.f,
1003 recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
1004 cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
1005 input_layer_norm_coefficients_ptr, forget_layer_norm_coefficients_ptr,
1006 cell_layer_norm_coefficients_ptr, output_layer_norm_coefficients_ptr,
1007 input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
1008 output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
1009 params, n_batch, n_cell, n_input, aux_input_size, n_output,
1010 output_batch_leading_dim, activation_state->data.f,
1011 cell_state->data.f, input_gate_scratch, forget_gate_scratch,
1012 cell_scratch, output_gate_scratch, output_ptr_time);
1013 }
1014 } else {
1015 for (int b = 0; b < n_batch; b++) {
1016 const int input_step = n_input;
1017 const int output_step = output_batch_leading_dim;
1018 for (int t = 0; t < max_time; t++) {
1019 // If this is the forward_sequence, step forward, otherwise step
1020 // backwards.
1021 const int t_rel = forward_sequence ? t : max_time - t - 1;
1022 const int time_offset = b * max_time + t_rel;
1023 const float* input_ptr = input->data.f + time_offset * input_step;
1024 if (aux_input) {
1025 aux_input_ptr = aux_input->data.f + time_offset * aux_input_size;
1026 }
1027 float* output_ptr =
1028 output->data.f + time_offset * output_step + output_offset;
1029
1030 // Offset the {activation,cell}_state pointers to the right batch.
1031 float* activation_state_ptr =
1032 activation_state->data.f + b * output_batch_leading_dim;
1033 float* cell_state_ptr = cell_state->data.f + b * n_cell;
1034 // Offset the scratch pointers to the right batch.
1035 float* input_gate_scratch_ptr =
1036 input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
1037 float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
1038 float* cell_scratch_ptr = cell_scratch + b * n_cell;
1039 float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
1040
1041 LstmStepWithAuxInput(
1042 input_ptr, input_to_input_weights_ptr,
1043 input_to_forget_weights->data.f, input_to_cell_weights->data.f,
1044 input_to_output_weights->data.f, aux_input_ptr,
1045 aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
1046 aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
1047 recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
1048 recurrent_to_cell_weights->data.f,
1049 recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
1050 cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
1051 input_layer_norm_coefficients_ptr,
1052 forget_layer_norm_coefficients_ptr,
1053 cell_layer_norm_coefficients_ptr,
1054 output_layer_norm_coefficients_ptr, input_gate_bias_ptr,
1055 forget_gate_bias->data.f, cell_bias->data.f,
1056 output_gate_bias->data.f, projection_weights_ptr,
1057 projection_bias_ptr, params, /*n_batch=*/1, n_cell, n_input,
1058 aux_input_size, n_output, output_batch_leading_dim,
1059 activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
1060 forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
1061 output_ptr);
1062 }
1063 }
1064 }
1065 return kTfLiteOk;
1066 }
1067
1068 TfLiteStatus EvalHybrid(
1069 const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
1070 const TfLiteTensor* input_to_forget_weights,
1071 const TfLiteTensor* input_to_cell_weights,
1072 const TfLiteTensor* input_to_output_weights,
1073 const TfLiteTensor* recurrent_to_input_weights,
1074 const TfLiteTensor* recurrent_to_forget_weights,
1075 const TfLiteTensor* recurrent_to_cell_weights,
1076 const TfLiteTensor* recurrent_to_output_weights,
1077 const TfLiteTensor* cell_to_input_weights,
1078 const TfLiteTensor* cell_to_forget_weights,
1079 const TfLiteTensor* cell_to_output_weights,
1080 const TfLiteTensor* input_layer_norm_coefficients,
1081 const TfLiteTensor* forget_layer_norm_coefficients,
1082 const TfLiteTensor* cell_layer_norm_coefficients,
1083 const TfLiteTensor* output_layer_norm_coefficients,
1084 const TfLiteTensor* aux_input,
1085 const TfLiteTensor* aux_input_to_input_weights,
1086 const TfLiteTensor* aux_input_to_forget_weights,
1087 const TfLiteTensor* aux_input_to_cell_weights,
1088 const TfLiteTensor* aux_input_to_output_weights,
1089 const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
1090 const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
1091 const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
1092 const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
1093 int output_offset, TfLiteTensor* scratch_buffer,
1094 TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
1095 TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
1096 TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
1097 TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
1098 TfLiteTensor* cell_state, TfLiteTensor* output) {
1099 // For operations that use int8 instead of uint8 we need to fetch raw data
1100 // from the tensor differently. We use this bool for that condition.
1101 const bool is_uint8_hybrid = input_to_output_weights->type == kTfLiteUInt8;
1102 TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
1103 const int n_input = input->dims->data[input->dims->size - 1];
1104 int max_time, n_batch;
1105 if (input->dims->size == 2) {
1106 max_time = 1;
1107 n_batch = input->dims->data[0];
1108 } else {
1109 max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
1110 n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
1111 }
1112 const int aux_input_size =
1113 (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
1114 // n_cell and n_output will be the same size when there is no projection.
1115 const int n_cell = input_to_output_weights->dims->data[0];
1116 const int n_output = recurrent_to_output_weights->dims->data[1];
1117
1118 // Since we have already checked that weights are all there or none, we can
1119 // check the existence of only one to get the condition.
1120 const bool use_cifg = (input_to_input_weights == nullptr);
1121 const bool use_peephole = (cell_to_output_weights != nullptr);
1122 const bool is_layer_norm_lstm = (forget_layer_norm_coefficients != nullptr);
1123
1124 float* input_gate_scratch = nullptr;
1125 float* cell_scratch = nullptr;
1126 float* forget_gate_scratch = nullptr;
1127 float* output_gate_scratch = nullptr;
  if (use_cifg) {
    cell_scratch = scratch_buffer->data.f;
    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
  } else {
    input_gate_scratch = scratch_buffer->data.f;
    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
  }

  // Check optional tensors; the respective pointers can be null.
  int8_t* input_to_input_weights_ptr = nullptr;
  float input_to_input_weights_scale = 1.0f;
  int8_t* recurrent_to_input_weights_ptr = nullptr;
  float recurrent_to_input_weights_scale = 1.0f;
  float* input_gate_bias_ptr = nullptr;
  if (!use_cifg) {
    input_to_input_weights_ptr =
        GetInt8DataPtr(input_to_input_weights, is_uint8_hybrid);
    recurrent_to_input_weights_ptr =
        GetInt8DataPtr(recurrent_to_input_weights, is_uint8_hybrid);
    input_gate_bias_ptr = input_gate_bias->data.f;
    input_to_input_weights_scale = input_to_input_weights->params.scale;
    recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
  }

  int8_t* cell_to_input_weights_ptr = nullptr;
  int8_t* cell_to_forget_weights_ptr = nullptr;
  int8_t* cell_to_output_weights_ptr = nullptr;
  float cell_to_input_weights_scale = 1.0f;
  float cell_to_forget_weights_scale = 1.0f;
  float cell_to_output_weights_scale = 1.0f;
  if (use_peephole) {
    if (!use_cifg) {
      cell_to_input_weights_ptr =
          GetInt8DataPtr(cell_to_input_weights, is_uint8_hybrid);
      cell_to_input_weights_scale = cell_to_input_weights->params.scale;
    }
    cell_to_forget_weights_ptr =
        GetInt8DataPtr(cell_to_forget_weights, is_uint8_hybrid);
    cell_to_output_weights_ptr =
        GetInt8DataPtr(cell_to_output_weights, is_uint8_hybrid);
    cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
    cell_to_output_weights_scale = cell_to_output_weights->params.scale;
  }

  const float* input_layer_norm_coefficients_ptr =
      (is_layer_norm_lstm && !use_cifg) ? input_layer_norm_coefficients->data.f
                                        : nullptr;
  const float* forget_layer_norm_coefficients_ptr =
      is_layer_norm_lstm ? forget_layer_norm_coefficients->data.f : nullptr;
  const float* cell_layer_norm_coefficients_ptr =
      is_layer_norm_lstm ? cell_layer_norm_coefficients->data.f : nullptr;
  const float* output_layer_norm_coefficients_ptr =
      is_layer_norm_lstm ? output_layer_norm_coefficients->data.f : nullptr;

  const int8_t* projection_weights_ptr =
      (projection_weights == nullptr)
          ? nullptr
          : GetInt8DataPtr(projection_weights, is_uint8_hybrid);
  const float projection_weights_scale =
      (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
  const float* projection_bias_ptr =
      (projection_bias == nullptr) ? nullptr : projection_bias->data.f;

  // Required tensors; their pointers are non-null.
  const int8_t* input_to_forget_weights_ptr =
      GetInt8DataPtr(input_to_forget_weights, is_uint8_hybrid);
  const float input_to_forget_weights_scale =
      input_to_forget_weights->params.scale;
  const int8_t* input_to_cell_weights_ptr =
      GetInt8DataPtr(input_to_cell_weights, is_uint8_hybrid);
  const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
  const int8_t* input_to_output_weights_ptr =
      GetInt8DataPtr(input_to_output_weights, is_uint8_hybrid);
  const float input_to_output_weights_scale =
      input_to_output_weights->params.scale;
  const int8_t* recurrent_to_forget_weights_ptr =
      GetInt8DataPtr(recurrent_to_forget_weights, is_uint8_hybrid);
  const float recurrent_to_forget_weights_scale =
      recurrent_to_forget_weights->params.scale;
  const int8_t* recurrent_to_cell_weights_ptr =
      GetInt8DataPtr(recurrent_to_cell_weights, is_uint8_hybrid);
  const float recurrent_to_cell_weights_scale =
      recurrent_to_cell_weights->params.scale;
  const int8_t* recurrent_to_output_weights_ptr =
      GetInt8DataPtr(recurrent_to_output_weights, is_uint8_hybrid);
  const float recurrent_to_output_weights_scale =
      recurrent_to_output_weights->params.scale;
  const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
  const float* cell_bias_ptr = cell_bias->data.f;
  const float* output_gate_bias_ptr = output_gate_bias->data.f;

  // Temporary storage for quantized values and scaling factors.
  int8_t* quantized_input_ptr =
      GetInt8DataPtr(input_quantized, is_uint8_hybrid);
  int8_t* quantized_aux_input_ptr =
      (aux_input_quantized == nullptr)
          ? nullptr
          : GetInt8DataPtr(aux_input_quantized, is_uint8_hybrid);
  int8_t* quantized_output_state_ptr =
      GetInt8DataPtr(output_state_quantized, is_uint8_hybrid);
  int8_t* quantized_cell_state_ptr =
      GetInt8DataPtr(cell_state_quantized, is_uint8_hybrid);
  float* scaling_factors_ptr = scaling_factors->data.f;
  float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
  float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;

  // Auxiliary input and weights.
  float* aux_input_ptr = nullptr;
  int8_t* aux_input_to_input_weights_ptr = nullptr;
  int8_t* aux_input_to_forget_weights_ptr = nullptr;
  int8_t* aux_input_to_cell_weights_ptr = nullptr;
  int8_t* aux_input_to_output_weights_ptr = nullptr;
  float aux_input_to_input_weights_scale = 0.0f;
  float aux_input_to_forget_weights_scale = 0.0f;
  float aux_input_to_cell_weights_scale = 0.0f;
  float aux_input_to_output_weights_scale = 0.0f;
  if (aux_input_size > 0) {
    if (!use_cifg) {
      aux_input_to_input_weights_ptr =
          GetInt8DataPtr(aux_input_to_input_weights, is_uint8_hybrid);
    }
    aux_input_to_forget_weights_ptr =
        GetInt8DataPtr(aux_input_to_forget_weights, is_uint8_hybrid);
    aux_input_to_cell_weights_ptr =
        GetInt8DataPtr(aux_input_to_cell_weights, is_uint8_hybrid);
    aux_input_to_output_weights_ptr =
        GetInt8DataPtr(aux_input_to_output_weights, is_uint8_hybrid);
    if (!use_cifg) {
      aux_input_to_input_weights_scale =
          aux_input_to_input_weights->params.scale;
    }
    aux_input_to_forget_weights_scale =
        aux_input_to_forget_weights->params.scale;
    aux_input_to_cell_weights_scale = aux_input_to_cell_weights->params.scale;
    aux_input_to_output_weights_scale =
        aux_input_to_output_weights->params.scale;
  }

  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];
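  // output_batch_leading_dim is the innermost dimension of the output tensor;
  // it may be larger than n_output (e.g. when several LSTM outputs share one
  // buffer), in which case output_offset selects where this LSTM's slice
  // starts.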
  if (time_major) {
    // Feed the sequence into the LSTM step-by-step.
    const int input_step = n_batch * n_input;
    const int output_step = n_batch * output_batch_leading_dim;
    for (int t = 0; t < max_time; t++) {
      // If forward_sequence is true, step forward; otherwise step backward.
      const int t_rel = forward_sequence ? t : max_time - t - 1;
      const float* input_ptr_batch = input->data.f + t_rel * input_step;
      if (aux_input) {
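        // Note: this offset reuses input_step (n_batch * n_input), which
        // assumes the auxiliary input's inner dimension matches n_input.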
        aux_input_ptr = aux_input->data.f + t_rel * input_step;
      }
      float* output_ptr_batch =
          output->data.f + t_rel * output_step + output_offset;

      LstmStepWithAuxInput(
          input_ptr_batch, input_to_input_weights_ptr,
          input_to_input_weights_scale, input_to_forget_weights_ptr,
          input_to_forget_weights_scale, input_to_cell_weights_ptr,
          input_to_cell_weights_scale, input_to_output_weights_ptr,
          input_to_output_weights_scale, aux_input_ptr,
          aux_input_to_input_weights_ptr, aux_input_to_input_weights_scale,
          aux_input_to_forget_weights_ptr, aux_input_to_forget_weights_scale,
          aux_input_to_cell_weights_ptr, aux_input_to_cell_weights_scale,
          aux_input_to_output_weights_ptr, aux_input_to_output_weights_scale,
          recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
          recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
          recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
          recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
          cell_to_input_weights_ptr, cell_to_input_weights_scale,
          cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
          cell_to_output_weights_ptr, cell_to_output_weights_scale,
          input_layer_norm_coefficients_ptr, forget_layer_norm_coefficients_ptr,
          cell_layer_norm_coefficients_ptr, output_layer_norm_coefficients_ptr,
          input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
          output_gate_bias_ptr, projection_weights_ptr,
          projection_weights_scale, projection_bias_ptr, params, n_batch,
          n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
          input_gate_scratch, forget_gate_scratch, cell_scratch,
          output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
          recovered_cell_weights_ptr, quantized_input_ptr,
          quantized_aux_input_ptr, quantized_output_state_ptr,
          quantized_cell_state_ptr, output_state->data.f, cell_state->data.f,
          output_ptr_batch);
    }
  } else {
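    // Batch major: process each batch entry independently, stepping through
    // its time steps and invoking the LSTM step with an effective batch size
    // of 1.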
    for (int b = 0; b < n_batch; b++) {
      const int input_step = n_input;
      const int output_step = output_batch_leading_dim;
      for (int t = 0; t < max_time; t++) {
        // If forward_sequence is true, step forward; otherwise step backward.
        const int t_rel = forward_sequence ? t : max_time - t - 1;
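        // In batch-major layout, row (b * max_time + t_rel) of the flattened
        // [n_batch * max_time, ...] input and output buffers holds this batch
        // entry's data for the current time step.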
        const int time_offset = b * max_time + t_rel;
        const float* input_ptr = input->data.f + time_offset * input_step;
        if (aux_input) {
          aux_input_ptr = aux_input->data.f + time_offset * input_step;
        }
        float* output_ptr =
            output->data.f + time_offset * output_step + output_offset;

        // Offset the {output,cell}_state pointers to the right batch.
        float* output_state_ptr =
            output_state->data.f + b * output_batch_leading_dim;
        float* cell_state_ptr = cell_state->data.f + b * n_cell;
        // Offset the scratch pointers to the right batch.
        float* input_gate_scratch_ptr =
            input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
        float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
        float* cell_scratch_ptr = cell_scratch + b * n_cell;
        float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;

        LstmStepWithAuxInput(
            input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
            input_to_forget_weights_ptr, input_to_forget_weights_scale,
            input_to_cell_weights_ptr, input_to_cell_weights_scale,
            input_to_output_weights_ptr, input_to_output_weights_scale,
            aux_input_ptr, aux_input_to_input_weights_ptr,
            aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
            aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
            aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
            aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
            recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
            recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
            recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
            recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
            cell_to_input_weights_scale, cell_to_forget_weights_ptr,
            cell_to_forget_weights_scale, cell_to_output_weights_ptr,
            cell_to_output_weights_scale, input_layer_norm_coefficients_ptr,
            forget_layer_norm_coefficients_ptr,
            cell_layer_norm_coefficients_ptr,
            output_layer_norm_coefficients_ptr, input_gate_bias_ptr,
            forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
            projection_weights_ptr, projection_weights_scale,
            projection_bias_ptr, params,
            /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
            output_batch_leading_dim, input_gate_scratch_ptr,
            forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
            scaling_factors_ptr, prod_scaling_factors_ptr,
            recovered_cell_weights_ptr, quantized_input_ptr,
            quantized_aux_input_ptr, quantized_output_state_ptr,
            quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
            output_ptr);
      }
    }
  }

  return kTfLiteOk;
}

}  // namespace lstm_eval
}  // namespace builtin
}  // namespace ops
}  // namespace tflite