/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/activation_functor.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_eval.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace unidirectional_sequence_lstm {

// Input tensor of size {max_time, n_batch, n_input}
constexpr int kInputTensor = 0;

// Input weight tensors of size: {n_cell, n_input}
constexpr int kInputToInputWeightsTensor = 1;  // Optional
constexpr int kInputToForgetWeightsTensor = 2;
constexpr int kInputToCellWeightsTensor = 3;
constexpr int kInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
constexpr int kRecurrentToForgetWeightsTensor = 6;
constexpr int kRecurrentToCellWeightsTensor = 7;
constexpr int kRecurrentToOutputWeightsTensor = 8;

// Peephole weight tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kCellToInputWeightsTensor = 9;    // Optional
constexpr int kCellToForgetWeightsTensor = 10;  // Optional
constexpr int kCellToOutputWeightsTensor = 11;  // Optional

// Gate bias tensors of size {n_cell}
constexpr int kInputGateBiasTensor = 12;  // Optional
constexpr int kForgetGateBiasTensor = 13;
constexpr int kCellGateBiasTensor = 14;
constexpr int kOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17;  // Optional

// Stateful input tensors that are variables and will be modified by the Op.
// Activation state tensor of size {n_batch, n_output}
constexpr int kInputActivationStateTensor = 18;
// Cell state tensor of size {n_batch, n_cell}
constexpr int kInputCellStateTensor = 19;
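
// Altogether the op consumes 20 input tensors (indices 0-19) and produces a
// single output tensor; Prepare() below verifies these counts.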

// Output tensors.
constexpr int kOutputTensor = 0;

// Temporary tensors
enum TemporaryTensor {
  kScratchBuffer = 0,
  kInputQuantized = 1,
  kOutputStateQuantized = 2,
  kCellStateQuantized = 3,
  kScalingFactors = 4,
  kProductScalingFactors = 5,
  kRecoveredCellWeights = 6,
  kNumTemporaryTensors = 7
};
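// The float kernel only uses kScratchBuffer; the hybrid (8-bit weight) kernel
// uses all kNumTemporaryTensors temporaries, as allocated in Prepare() below.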

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* scratch_tensor_index = new int();
  context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
  return scratch_tensor_index;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<int*>(buffer);
}

// Check that the input tensor dimensions match each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell) {
  const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);

  // Make sure the clipping parameters have valid values.
  // == 0 means no clipping
  //  > 0 means clipping
  TF_LITE_ENSURE(context, params->cell_clip >= 0);
  TF_LITE_ENSURE(context, params->proj_clip >= 0);

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  if (input_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
  }

  const TfLiteTensor* input_to_forget_weights =
      GetInput(context, node, kInputToForgetWeightsTensor);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);

  const TfLiteTensor* input_to_cell_weights =
      GetInput(context, node, kInputToCellWeightsTensor);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);

  const TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  if (recurrent_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
                      n_output);
  }

  const TfLiteTensor* recurrent_to_forget_weights =
      GetInput(context, node, kRecurrentToForgetWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                    n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
                    n_output);

  const TfLiteTensor* recurrent_to_cell_weights =
      GetInput(context, node, kRecurrentToCellWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
                    n_output);

  // Make sure the input-gate parameters are either both present (regular
  // LSTM) or both absent (CIFG-LSTM).
  const bool cifg_weights_all_or_none =
      ((input_to_input_weights != nullptr) &&
       (recurrent_to_input_weights != nullptr)) ||
      ((input_to_input_weights == nullptr) &&
       (recurrent_to_input_weights == nullptr));
  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);

  const TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  if (cell_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
  }

  const TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  if (cell_to_forget_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
  }

  const TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
  if (cell_to_output_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
  }

  // Make sure the peephole weights are either all present or all absent.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool peephole_weights_all_or_none =
      ((cell_to_input_weights != nullptr || use_cifg) &&
       (cell_to_forget_weights != nullptr) &&
       (cell_to_output_weights != nullptr)) ||
      ((cell_to_input_weights == nullptr) &&
       (cell_to_forget_weights == nullptr) &&
       (cell_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);

  // Make sure the input gate bias is present only when not a CIFG-LSTM.
  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
  if (use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
  } else {
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
  }

  const TfLiteTensor* forget_gate_bias =
      GetInput(context, node, kForgetGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);

  const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);

  const TfLiteTensor* output_gate_bias =
      GetInput(context, node, kOutputGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);

  const TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
  if (projection_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
  }

  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, kProjectionBiasTensor);
  if (projection_bias != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
  }

  // Make sure the projection tensors are consistent:
  // 1) If the projection weight is not present, then the projection bias
  //    should not be present either.
  // 2) If the projection weight is present, then the projection bias is
  //    optional.
  // TODO(ghodrat): make sure this is correct.
  const bool projection_tensors_consistent =
      ((projection_weights != nullptr) || (projection_bias == nullptr));
  TF_LITE_ENSURE(context, projection_tensors_consistent == true);

  return kTfLiteOk;
}

// Resize the output and state tensors based on the sizes of the input
// tensors. Allocate a temporary scratch tensor. Also check that the sizes of
// the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);

  // Check that we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 20);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

  // Infer the batch size, number of outputs, sequence length and number of
  // cells from the input tensors.
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE(context, input->dims->size > 1);
  const auto* params =
      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
          node->builtin_data);
  const bool time_major = params->time_major;
  const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
  const int n_input = input->dims->data[2];

  const TfLiteTensor* input_to_output_weights =
      GetInput(context, node, kInputToOutputWeightsTensor);
  const int n_cell = input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);

  const TfLiteTensor* recurrent_to_output_weights =
      GetInput(context, node, kRecurrentToOutputWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
                    n_cell);
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Check that the input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
                                                        n_output, n_cell));

  // Get the pointers to the output, activation_state and cell_state buffer
  // tensors.
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TfLiteTensor* activation_state =
      GetVariableInput(context, node, kInputActivationStateTensor);
  TfLiteTensor* cell_state =
      GetVariableInput(context, node, kInputCellStateTensor);

  // Check the shape of the input state tensors.
  // These tensors may be 1D or 2D. It's fine as long as the total size is
  // correct.
  TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
  TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);

  // Resize the output tensor.
  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
  output_size->data[input->dims->size - 1] = n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));

  // The weights are of consistent type, so it suffices to check one.
  // TODO(mirkov): create a utility/macro for this check, so all Ops can use it.
  const bool is_hybrid_op = ((input_to_output_weights->type == kTfLiteUInt8 ||
                              input_to_output_weights->type == kTfLiteInt8) &&
                             input->type == kTfLiteFloat32);
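
  // A hybrid op keeps float inputs and states but stores the weights in
  // 8 bits; it needs the extra temporaries allocated below to hold quantized
  // copies of the input and state tensors plus the per-batch scaling factors.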

  TfLiteIntArrayFree(node->temporaries);
  if (is_hybrid_op) {
    node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
  } else {
    node->temporaries = TfLiteIntArrayCreate(1);
  }
  node->temporaries->data[0] = *scratch_tensor_index;

  // Create a scratch buffer tensor.
  TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer);
  scratch_buffer->type = input->type;
  scratch_buffer->allocation_type = kTfLiteArenaRw;

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  const bool use_cifg = (input_to_input_weights == nullptr);
  TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
  scratch_buffer_size->data[0] = n_batch;
  if (use_cifg) {
    // Reserving space for Cell, Forget, Output gates
    scratch_buffer_size->data[1] = n_cell * 3;
  } else {
    // Reserving space for Input, Cell, Forget, Output gates
    scratch_buffer_size->data[1] = n_cell * 4;
  }
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
                                                   scratch_buffer_size));

  if (is_hybrid_op) {
    // Allocate temporary tensors to store quantized values of the input,
    // activation_state and cell_state tensors.
    node->temporaries->data[kInputQuantized] =
        *scratch_tensor_index + kInputQuantized;
    TfLiteTensor* input_quantized =
        GetTemporary(context, node, kInputQuantized);
    input_quantized->type = input_to_output_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }
    node->temporaries->data[kOutputStateQuantized] =
        *scratch_tensor_index + kOutputStateQuantized;
    TfLiteTensor* activation_state_quantized =
        GetTemporary(context, node, kOutputStateQuantized);
    activation_state_quantized->type = input_to_output_weights->type;
    activation_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
                             activation_state->dims)) {
      TfLiteIntArray* activation_state_quantized_size =
          TfLiteIntArrayCopy(activation_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, activation_state_quantized,
                                         activation_state_quantized_size));
    }
    node->temporaries->data[kCellStateQuantized] =
        *scratch_tensor_index + kCellStateQuantized;
    TfLiteTensor* cell_state_quantized =
        GetTemporary(context, node, kCellStateQuantized);
    cell_state_quantized->type = input_to_output_weights->type;
    cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
      TfLiteIntArray* cell_state_quantized_size =
          TfLiteIntArrayCopy(cell_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, cell_state_quantized,
                                              cell_state_quantized_size));
    }

    // Allocate temporary tensors to store the scaling factors and product
    // scaling factors. The latter is convenience storage that lets a vector be
    // quantized once (producing the scaling factors) and then multiplied with
    // different matrices (which requires multiplying its scaling factors with
    // the scaling factor of each matrix).
    node->temporaries->data[kScalingFactors] =
        *scratch_tensor_index + kScalingFactors;
    TfLiteTensor* scaling_factors =
        GetTemporary(context, node, kScalingFactors);
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {n_batch};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }
    node->temporaries->data[kProductScalingFactors] =
        *scratch_tensor_index + kProductScalingFactors;
    TfLiteTensor* prod_scaling_factors =
        GetTemporary(context, node, kProductScalingFactors);
    prod_scaling_factors->type = kTfLiteFloat32;
    prod_scaling_factors->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
                                   scaling_dims)) {
      TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
      prod_scaling_factors_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, prod_scaling_factors,
                                              prod_scaling_factors_size));
    }

    // Allocate a temporary tensor to store the recovered cell weights. Since
    // these are used for diagonal matrices, only n_cell values need to be
    // stored.
    node->temporaries->data[kRecoveredCellWeights] =
        *scratch_tensor_index + kRecoveredCellWeights;
    TfLiteTensor* recovered_cell_weights =
        GetTemporary(context, node, kRecoveredCellWeights);
    recovered_cell_weights->type = kTfLiteFloat32;
    recovered_cell_weights->allocation_type = kTfLiteArenaRw;
    int recovered_cell_dims[1] = {n_cell};
    if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1,
                                   recovered_cell_dims)) {
      TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
      recovered_cell_weights_size->data[0] = n_cell;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, recovered_cell_weights,
                                              recovered_cell_weights_size));
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params =
      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
          node->builtin_data);
  const bool time_major = params->time_major;
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights =
      GetInput(context, node, kInputToForgetWeightsTensor);
  const TfLiteTensor* input_to_cell_weights =
      GetInput(context, node, kInputToCellWeightsTensor);
  const TfLiteTensor* input_to_output_weights =
      GetInput(context, node, kInputToOutputWeightsTensor);

  const TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights =
      GetInput(context, node, kRecurrentToForgetWeightsTensor);
  const TfLiteTensor* recurrent_to_cell_weights =
      GetInput(context, node, kRecurrentToCellWeightsTensor);
  const TfLiteTensor* recurrent_to_output_weights =
      GetInput(context, node, kRecurrentToOutputWeightsTensor);

  const TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  const TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  const TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);

  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
  const TfLiteTensor* forget_gate_bias =
      GetInput(context, node, kForgetGateBiasTensor);
  const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
  const TfLiteTensor* output_gate_bias =
      GetInput(context, node, kOutputGateBiasTensor);

  const TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, kProjectionBiasTensor);

  // Index the scratch buffer pointers to the global scratch buffer.
  TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);

  TfLiteTensor* activation_state =
      GetVariableInput(context, node, kInputActivationStateTensor);
  TfLiteTensor* cell_state =
      GetVariableInput(context, node, kInputCellStateTensor);

  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  // Copy out the LSTM-specific params so they can be passed to the eval
  // functions.
  TfLiteLSTMParams lstm_params;
  lstm_params.activation = params->activation;
  lstm_params.cell_clip = params->cell_clip;
  lstm_params.proj_clip = params->proj_clip;
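
  // Dispatch on the weight type: float weights use the float kernel, while
  // uint8/int8 weights use the hybrid kernel, which quantizes the float
  // activations on the fly using the temporaries allocated in Prepare().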
  switch (input_to_output_weights->type) {
    case kTfLiteFloat32: {
      return lstm_eval::EvalFloat(
          input, input_to_input_weights, input_to_forget_weights,
          input_to_cell_weights, input_to_output_weights,
          recurrent_to_input_weights, recurrent_to_forget_weights,
          recurrent_to_cell_weights, recurrent_to_output_weights,
          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr,
          /*aux_input=*/nullptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
          projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
          /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
          output);
    }
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
      TfLiteTensor* activation_state_quantized =
          GetTemporary(context, node, /*index=*/2);
      TfLiteTensor* cell_state_quantized =
          GetTemporary(context, node, /*index=*/3);
      TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
      TfLiteTensor* prod_scaling_factors =
          GetTemporary(context, node, /*index=*/5);
      TfLiteTensor* recovered_cell_weights =
          GetTemporary(context, node, /*index=*/6);
      return lstm_eval::EvalHybrid(
          input, input_to_input_weights, input_to_forget_weights,
          input_to_cell_weights, input_to_output_weights,
          recurrent_to_input_weights, recurrent_to_forget_weights,
          recurrent_to_cell_weights, recurrent_to_output_weights,
          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr,
          /*aux_input=*/nullptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
          projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
          /*output_offset=*/0, scratch_buffer, scaling_factors,
          prod_scaling_factors, recovered_cell_weights, input_quantized,
          /*aux_input_quantized=*/nullptr, activation_state_quantized,
          cell_state_quantized, activation_state, cell_state, output);
    }
    default:
      context->ReportError(context, "Type %d is not currently supported.",
                           input_to_output_weights->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}
}  // namespace unidirectional_sequence_lstm

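// A minimal registration sketch (assumed usage, not part of this file): an op
// resolver would typically register this kernel via
//   resolver.AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
//                       Register_UNIDIRECTIONAL_SEQUENCE_LSTM());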
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
  static TfLiteRegistration r = {unidirectional_sequence_lstm::Init,
                                 unidirectional_sequence_lstm::Free,
                                 unidirectional_sequence_lstm::Prepare,
                                 unidirectional_sequence_lstm::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite