/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/activation_functor.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_eval.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace unidirectional_sequence_lstm {

// Input tensor of size {max_time, n_batch, n_input}
constexpr int kInputTensor = 0;

// Input weight tensors of size: {n_cell, n_input}
constexpr int kInputToInputWeightsTensor = 1;  // Optional
constexpr int kInputToForgetWeightsTensor = 2;
constexpr int kInputToCellWeightsTensor = 3;
constexpr int kInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
constexpr int kRecurrentToForgetWeightsTensor = 6;
constexpr int kRecurrentToCellWeightsTensor = 7;
constexpr int kRecurrentToOutputWeightsTensor = 8;

// Peephole weight tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kCellToInputWeightsTensor = 9;    // Optional
constexpr int kCellToForgetWeightsTensor = 10;  // Optional
constexpr int kCellToOutputWeightsTensor = 11;  // Optional

// Gate bias tensors of size {n_cell}
constexpr int kInputGateBiasTensor = 12;  // Optional
constexpr int kForgetGateBiasTensor = 13;
constexpr int kCellGateBiasTensor = 14;
constexpr int kOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17;  // Optional

// Stateful input tensors that are variables and will be modified by the Op.
// Activation state tensor of size {n_batch, n_output}
constexpr int kInputActivationStateTensor = 18;
// Cell state tensor of size {n_batch, n_cell}
constexpr int kInputCellStateTensor = 19;

// Output tensors.
constexpr int kOutputTensor = 0;

// Temporary tensors.
enum TemporaryTensor {
  kScratchBuffer = 0,
  kInputQuantized = 1,
  kOutputStateQuantized = 2,
  kCellStateQuantized = 3,
  kScalingFactors = 4,
  kProductScalingFactors = 5,
  kRecoveredCellWeights = 6,
  kNumTemporaryTensors = 7
};
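
// The float kernel uses only kScratchBuffer; the remaining temporaries back
// the hybrid path (8-bit weights with float activations), which quantizes the
// input and state tensors on the fly. Init() registers kNumTemporaryTensors
// consecutive tensor slots, so each enum value is used as an offset from the
// base index stored in node->user_data.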

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* scratch_tensor_index = new int();
  context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
  return scratch_tensor_index;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<int*>(buffer);
}

// Check that the input tensor dimensions match each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell) {
  const auto* params =
      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
          node->builtin_data);

  // Make sure the clipping parameters have valid values:
  //   == 0 means no clipping
  //    > 0 means clipping
  TF_LITE_ENSURE(context, params->cell_clip >= 0);
  TF_LITE_ENSURE(context, params->proj_clip >= 0);
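  // For example, cell_clip = 3.0f clamps every cell-state value to
  // [-3.0, 3.0] in the evaluation routines; proj_clip is applied the same way
  // to the projection output.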

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  if (input_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
  }

  const TfLiteTensor* input_to_forget_weights =
      GetInput(context, node, kInputToForgetWeightsTensor);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);

  const TfLiteTensor* input_to_cell_weights =
      GetInput(context, node, kInputToCellWeightsTensor);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);

  const TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  if (recurrent_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
                      n_output);
  }

  const TfLiteTensor* recurrent_to_forget_weights =
      GetInput(context, node, kRecurrentToForgetWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                    n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
                    n_output);

  const TfLiteTensor* recurrent_to_cell_weights =
      GetInput(context, node, kRecurrentToCellWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
                    n_output);

  // We make sure the input-gate's parameters are either both present (regular
  // LSTM) or not at all (CIFG-LSTM).
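  // In a CIFG ("Coupled Input and Forget Gate") LSTM, the input gate is
  // derived from the forget gate (input_gate = 1 - forget_gate), so the
  // input-gate weights and bias are omitted from the model altogether.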
  const bool cifg_weights_all_or_none =
      ((input_to_input_weights != nullptr) &&
       (recurrent_to_input_weights != nullptr)) ||
      ((input_to_input_weights == nullptr) &&
       (recurrent_to_input_weights == nullptr));
  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);

  const TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  if (cell_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
  }

  const TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  if (cell_to_forget_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
  }

  const TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
  if (cell_to_output_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
  }

  // Make sure the peephole weights are either all present or all absent.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool peephole_weights_all_or_none =
      ((cell_to_input_weights != nullptr || use_cifg) &&
       (cell_to_forget_weights != nullptr) &&
       (cell_to_output_weights != nullptr)) ||
      ((cell_to_input_weights == nullptr) &&
       (cell_to_forget_weights == nullptr) &&
       (cell_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);

  // Make sure the input gate bias is present only when not a CIFG-LSTM.
  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
  if (use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
  } else {
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
  }

  const TfLiteTensor* forget_gate_bias =
      GetInput(context, node, kForgetGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);

  const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);

  const TfLiteTensor* output_gate_bias =
      GetInput(context, node, kOutputGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);

  const TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
  if (projection_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
  }

  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, kProjectionBiasTensor);
  if (projection_bias != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
  }

  // Make sure the projection tensors are consistent:
  // 1) If the projection weights are not present, then the projection bias
  //    should not be present either.
  // 2) If the projection weights are present, then the projection bias is
  //    optional.
  // TODO(ghodrat): make sure this is correct.
  const bool projection_tensors_consistent =
      ((projection_weights != nullptr) || (projection_bias == nullptr));
  TF_LITE_ENSURE(context, projection_tensors_consistent == true);

  return kTfLiteOk;
}

// Resize the output and state tensors based on the sizes of the input tensors.
// Allocate a temporary scratch tensor. Also check that the sizes of the input
// tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);

  // Check that we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 20);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

  // Infer the batch size, number of cells and number of outputs from the
  // input and weight tensors.
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE(context, input->dims->size > 1);
  const auto* params =
      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
          node->builtin_data);
  const bool time_major = params->time_major;
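  // With time_major the input is laid out as {max_time, n_batch, n_input};
  // otherwise it is {n_batch, max_time, n_input}, which is why the batch size
  // comes from dims[1] or dims[0] respectively while n_input is always dims[2].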
  const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
  const int n_input = input->dims->data[2];

  const TfLiteTensor* input_to_output_weights =
      GetInput(context, node, kInputToOutputWeightsTensor);
  const int n_cell = input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);

  const TfLiteTensor* recurrent_to_output_weights =
      GetInput(context, node, kRecurrentToOutputWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
                    n_cell);
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Check that the input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
                                                        n_output, n_cell));

  // Get pointers to the output, activation_state and cell_state buffer
  // tensors.
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TfLiteTensor* activation_state =
      GetVariableInput(context, node, kInputActivationStateTensor);
  TfLiteTensor* cell_state =
      GetVariableInput(context, node, kInputCellStateTensor);

  // Check the shape of the input state tensors.
  // These tensors may be 1D or 2D. It's fine as long as the total size is
  // correct.
  TF_LITE_ENSURE_EQ(context, NumElements(activation_state),
                    n_batch * n_output);
  TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);

  // Resize the output tensors.
  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
  output_size->data[input->dims->size - 1] = n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));

  // The weights are of consistent type, so it suffices to check one.
  // TODO(mirkov): create a utility/macro for this check, so all Ops can use it.
  const bool is_hybrid_op = ((input_to_output_weights->type == kTfLiteUInt8 ||
                              input_to_output_weights->type == kTfLiteInt8) &&
                             input->type == kTfLiteFloat32);
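  // "Hybrid" means the weights are stored as 8-bit integers while activations
  // stay in float; the extra temporaries allocated below hold the dynamically
  // quantized input/state tensors and their scaling factors.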

  TfLiteIntArrayFree(node->temporaries);
  if (is_hybrid_op) {
    node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
  } else {
    node->temporaries = TfLiteIntArrayCreate(1);
  }
  node->temporaries->data[0] = *scratch_tensor_index;

  // Create a scratch buffer tensor.
  TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer);
  scratch_buffer->type = input->type;
  scratch_buffer->allocation_type = kTfLiteArenaRw;

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  const bool use_cifg = (input_to_input_weights == nullptr);
  TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
  scratch_buffer_size->data[0] = n_batch;
  if (use_cifg) {
    // Reserving space for Cell, Forget, Output gates
    scratch_buffer_size->data[1] = n_cell * 3;
  } else {
    // Reserving space for Input, Cell, Forget, Output gates
    scratch_buffer_size->data[1] = n_cell * 4;
  }
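  // For example (illustrative numbers only): with n_batch = 2 and n_cell = 10,
  // a non-CIFG cell gets a {2, 40} scratch buffer, while a CIFG cell only
  // needs {2, 30}.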
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
                                                   scratch_buffer_size));

  if (is_hybrid_op) {
    // Allocate temporary tensors to store quantized values of input,
    // activation_state and cell_state tensors.
    node->temporaries->data[kInputQuantized] =
        *scratch_tensor_index + kInputQuantized;
    TfLiteTensor* input_quantized =
        GetTemporary(context, node, kInputQuantized);
    input_quantized->type = input_to_output_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }
    node->temporaries->data[kOutputStateQuantized] =
        *scratch_tensor_index + kOutputStateQuantized;
    TfLiteTensor* activation_state_quantized =
        GetTemporary(context, node, kOutputStateQuantized);
    activation_state_quantized->type = input_to_output_weights->type;
    activation_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
                             activation_state->dims)) {
      TfLiteIntArray* activation_state_quantized_size =
          TfLiteIntArrayCopy(activation_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, activation_state_quantized,
                                         activation_state_quantized_size));
    }
    node->temporaries->data[kCellStateQuantized] =
        *scratch_tensor_index + kCellStateQuantized;
    TfLiteTensor* cell_state_quantized =
        GetTemporary(context, node, kCellStateQuantized);
    cell_state_quantized->type = input_to_output_weights->type;
    cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
      TfLiteIntArray* cell_state_quantized_size =
          TfLiteIntArrayCopy(cell_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, cell_state_quantized,
                                              cell_state_quantized_size));
    }

    // Allocate temporary tensors to store scaling factors and product scaling
    // factors. The latter is convenience storage that lets a vector be
    // quantized once (producing the scaling factors) and then multiplied with
    // different matrices (which requires multiplying the scaling factors by
    // the scaling factor of the matrix).
    node->temporaries->data[kScalingFactors] =
        *scratch_tensor_index + kScalingFactors;
    TfLiteTensor* scaling_factors =
        GetTemporary(context, node, kScalingFactors);
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {n_batch};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }
    node->temporaries->data[kProductScalingFactors] =
        *scratch_tensor_index + kProductScalingFactors;
    TfLiteTensor* prod_scaling_factors =
        GetTemporary(context, node, kProductScalingFactors);
    prod_scaling_factors->type = kTfLiteFloat32;
    prod_scaling_factors->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
                                   scaling_dims)) {
      TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
      prod_scaling_factors_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, prod_scaling_factors,
                                              prod_scaling_factors_size));
    }

    // Allocate a temporary tensor to store the recovered cell weights. Since
    // these are used as diagonal matrices, only n_cell values need to be
    // stored.
    node->temporaries->data[kRecoveredCellWeights] =
        *scratch_tensor_index + kRecoveredCellWeights;
    TfLiteTensor* recovered_cell_weights =
        GetTemporary(context, node, kRecoveredCellWeights);
    recovered_cell_weights->type = kTfLiteFloat32;
    recovered_cell_weights->allocation_type = kTfLiteArenaRw;
    int recovered_cell_dims[1] = {n_cell};
    if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1,
                                   recovered_cell_dims)) {
      TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
      recovered_cell_weights_size->data[0] = n_cell;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, recovered_cell_weights,
                                              recovered_cell_weights_size));
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params =
      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
          node->builtin_data);
  const bool time_major = params->time_major;
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights =
      GetInput(context, node, kInputToForgetWeightsTensor);
  const TfLiteTensor* input_to_cell_weights =
      GetInput(context, node, kInputToCellWeightsTensor);
  const TfLiteTensor* input_to_output_weights =
      GetInput(context, node, kInputToOutputWeightsTensor);

  const TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights =
      GetInput(context, node, kRecurrentToForgetWeightsTensor);
  const TfLiteTensor* recurrent_to_cell_weights =
      GetInput(context, node, kRecurrentToCellWeightsTensor);
  const TfLiteTensor* recurrent_to_output_weights =
      GetInput(context, node, kRecurrentToOutputWeightsTensor);

  const TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  const TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  const TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);

  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
  const TfLiteTensor* forget_gate_bias =
      GetInput(context, node, kForgetGateBiasTensor);
  const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
  const TfLiteTensor* output_gate_bias =
      GetInput(context, node, kOutputGateBiasTensor);

  const TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, kProjectionBiasTensor);

  // Index the scratch buffer pointers into the global scratch buffer.
  TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);

  TfLiteTensor* activation_state =
      GetVariableInput(context, node, kInputActivationStateTensor);
  TfLiteTensor* cell_state =
      GetVariableInput(context, node, kInputCellStateTensor);

  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  // Copy out the LSTM-specific params so they can be passed to the eval
  // functions.
  TfLiteLSTMParams lstm_params;
  lstm_params.activation = params->activation;
  lstm_params.cell_clip = params->cell_clip;
  lstm_params.proj_clip = params->proj_clip;

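  // Dispatch on the weight type: float weights take the float path, while
  // uint8/int8 weights take the hybrid path that consumes the quantization
  // temporaries allocated in Prepare().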
  switch (input_to_output_weights->type) {
    case kTfLiteFloat32: {
      return lstm_eval::EvalFloat(
          input, input_to_input_weights, input_to_forget_weights,
          input_to_cell_weights, input_to_output_weights,
          recurrent_to_input_weights, recurrent_to_forget_weights,
          recurrent_to_cell_weights, recurrent_to_output_weights,
          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr,
          /*aux_input=*/nullptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
          projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
          /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
          output);
    }
    case kTfLiteUInt8:
    case kTfLiteInt8: {
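      // The temporary indices below follow the TemporaryTensor enum used in
      // Prepare() (kInputQuantized = 1, ..., kRecoveredCellWeights = 6).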
      TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
      TfLiteTensor* activation_state_quantized =
          GetTemporary(context, node, /*index=*/2);
      TfLiteTensor* cell_state_quantized =
          GetTemporary(context, node, /*index=*/3);
      TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
      TfLiteTensor* prod_scaling_factors =
          GetTemporary(context, node, /*index=*/5);
      TfLiteTensor* recovered_cell_weights =
          GetTemporary(context, node, /*index=*/6);
      return lstm_eval::EvalHybrid(
          input, input_to_input_weights, input_to_forget_weights,
          input_to_cell_weights, input_to_output_weights,
          recurrent_to_input_weights, recurrent_to_forget_weights,
          recurrent_to_cell_weights, recurrent_to_output_weights,
          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr,
          /*aux_input=*/nullptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
          projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
          /*output_offset=*/0, scratch_buffer, scaling_factors,
          prod_scaling_factors, recovered_cell_weights, input_quantized,
          /*aux_input_quantized=*/nullptr, activation_state_quantized,
          cell_state_quantized, activation_state, cell_state, output);
    }
    default:
      context->ReportError(context, "Type %d is not currently supported.",
                           input_to_output_weights->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}
}  // namespace unidirectional_sequence_lstm

TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
  static TfLiteRegistration r = {unidirectional_sequence_lstm::Init,
                                 unidirectional_sequence_lstm::Free,
                                 unidirectional_sequence_lstm::Prepare,
                                 unidirectional_sequence_lstm::Eval};
  return &r;
}
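
// A minimal usage sketch (an assumption, not part of this kernel): models that
// contain this op are normally run through the standard TFLite client API,
// which reaches Register_UNIDIRECTIONAL_SEQUENCE_LSTM() via the builtin
// resolver rather than calling it directly:
//
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   std::unique_ptr<tflite::Interpreter> interpreter;
//   tflite::InterpreterBuilder(*model, resolver)(&interpreter);
//   interpreter->AllocateTensors();
//   interpreter->Invoke();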

}  // namespace builtin
}  // namespace ops
}  // namespace tflite