/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/lstm_parser.h"

#include <optional>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/types/any.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/lstm_shared.h"
#include "tensorflow/lite/string_type.h"

namespace tflite {
namespace gpu {
namespace {

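// Creates a new Value with the same shape, type, and quantization parameters
// as old_value. The tensor ref is set to -1 because the new value is internal
// to the GPU graph and has no corresponding tensor in the TFLite model.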
Value* CreateNewSimilarValue(GraphFloat32* graph, const Value* old_value) {
  Value* new_value = graph->NewValue();
  new_value->quant_params = old_value->quant_params;
  new_value->tensor.shape = old_value->tensor.shape;
  new_value->tensor.type = old_value->tensor.type;
  new_value->tensor.ref = -1;
  return new_value;
}

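// Reads a 2-D (HW) weights tensor and reinterprets it in the OHWI layout
// expected by the GPU FULLY_CONNECTED op. For example, an LSTM gate weights
// tensor of shape [num_units, input_size] maps to o = num_units, h = w = 1,
// i = input_size.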
absl::Status SetFullyConnectedWeights(int weights_tensor_id,
                                      ObjectReader* reader,
                                      FullyConnectedAttributes* attr) {
  Tensor<HW, DataType::FLOAT32> weights;
  RETURN_IF_ERROR(reader->ReadTensor(weights_tensor_id, &weights));
  attr->weights.data = std::move(weights.data);
  attr->weights.id = weights.id;
  attr->weights.shape.o = weights.shape.h;
  attr->weights.shape.h = 1;
  attr->weights.shape.w = 1;
  attr->weights.shape.i = weights.shape.w;
  return absl::OkStatus();
}

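// Returns true if the node's input at the given index is present. Optional
// inputs are marked with kTfLiteOptionalTensor (-1) in the input list.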
bool HasTensor(const TfLiteNode* node, const int index) {
  return (index < node->inputs->size) &&
         (node->inputs->data[index] != kTfLiteOptionalTensor);
}

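// CIFG (coupled input and forget gate) models ship without input gate
// weights; the input gate is instead derived as 1 - forget_gate.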
bool HasCifg(const TfLiteNode* node) {
  return !HasTensor(
      node, tflite::ops::builtin::lstm::full::kInputToInputWeightsTensor);
}

bool HasPeephole(const TfLiteNode* node) {
  // Use the cell-to-forget weights to detect peephole connections, since the
  // cell-to-input weights may be absent for CIFG models.
  return HasTensor(
      node, tflite::ops::builtin::lstm::full::kCellToForgetWeightsTensor);
}

bool HasNormalization(const TfLiteNode* node) {
  return HasTensor(
      node,
      tflite::ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor);
}

bool HasProjection(const TfLiteNode* node) {
  return HasTensor(node,
                   tflite::ops::builtin::lstm::full::kProjectionWeightsTensor);
}

// Builds a subgraph for a single LSTM gate.
// Returns a Value representing the gate's output.
// High-level parameters:
// - Has normalization (if true: provide normalization weights).
// - Has peephole connection (if true: provide peephole weights).
// - Which activation function to use.
// Note: no support for aux input.
//
// Implements the following
// (*: matrix multiply, .*: elementwise multiply, +: elementwise add):
// temp = input_weights * input_tensor + recurrent_weights * output_state;
// if (peephole):
//   temp += peephole_weights .* cell_state;
// if (layer normalization):
//   gate = activate(normalization_weights .* mean_stddev_norm(temp) + bias);
// else:
//   gate = activate(temp + bias);
//
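// The gate is assembled from generic FULLY_CONNECTED, MUL, ADD, and
// MEAN_STDDEV_NORMALIZATION nodes rather than a single fused LSTM kernel.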
absl::Status BuildLstmGate(GraphFloat32* graph, ObjectReader* reader,
                           Value* output_state, Value* cell_state,
                           int input_weight_id, int recurrent_weight_id,
                           int cell_weight_id, int bias_id,
                           int normalization_weight_id,
                           const TfLiteFusedActivation activation,
                           bool has_peephole, bool has_normalization,
                           Value** gate_out) {
  Value* input_times_weights = CreateNewSimilarValue(graph, cell_state);
  {
    // #1 matrix multiplication: input_weights * input_tensor
    // If there is no normalization, the bias is also added here.
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::FULLY_CONNECTED);
    FullyConnectedAttributes fc_attr;
    RETURN_IF_ERROR(
        SetFullyConnectedWeights(input_weight_id, reader, &fc_attr));
    if (!has_normalization) {
      RETURN_IF_ERROR(reader->ReadTensor(bias_id, &(fc_attr.bias)));
    }
    node->operation.attributes = std::move(fc_attr);
    RETURN_IF_ERROR(
        reader->AddInput(node, tflite::ops::builtin::lstm::full::kInputTensor));
    RETURN_IF_ERROR(graph->SetProducer(node->id, input_times_weights->id));
  }

  Value* output_state_times_weights = CreateNewSimilarValue(graph, cell_state);
  {
    // #2 matrix multiplication: recurrent_weights * output_state
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::FULLY_CONNECTED);
    FullyConnectedAttributes fc_attr;
    RETURN_IF_ERROR(
        SetFullyConnectedWeights(recurrent_weight_id, reader, &fc_attr));
    node->operation.attributes = std::move(fc_attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, output_state->id));
    RETURN_IF_ERROR(
        graph->SetProducer(node->id, output_state_times_weights->id));
  }

  Value* cell_state_times_weights;
  if (has_peephole) {
    // #3 elementwise multiplication: peephole_weights .* cell_state
    cell_state_times_weights = CreateNewSimilarValue(graph, cell_state);
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    ElementwiseAttributes attr;
    Tensor<Linear, DataType::FLOAT32> weights;
    RETURN_IF_ERROR(reader->ReadTensor(cell_weight_id, &weights));
    attr.param = std::move(weights);
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, cell_state_times_weights->id));
  }

  Value* gate_before_normalization = CreateNewSimilarValue(graph, cell_state);
  Node* add_node = graph->NewNode();
  {
    // #4 elementwise addition: #1 + #2 + #3
    add_node->operation.type = ToString(OperationType::ADD);
    RETURN_IF_ERROR(graph->AddConsumer(add_node->id, input_times_weights->id));
    RETURN_IF_ERROR(
        graph->AddConsumer(add_node->id, output_state_times_weights->id));
    if (has_peephole) {
      RETURN_IF_ERROR(
          graph->AddConsumer(add_node->id, cell_state_times_weights->id));
    }
    RETURN_IF_ERROR(
        graph->SetProducer(add_node->id, gate_before_normalization->id));
  }

  if (!has_normalization) {
    // #5 activation function: activate(temp + bias)
    // The bias was already added in node #1.
    RETURN_IF_ERROR(MaybeFuseActivation(activation, graph, add_node));
    *gate_out = gate_before_normalization;
    return absl::OkStatus();
  }

  Value* normalized_gate =
      CreateNewSimilarValue(graph, gate_before_normalization);
  {
    // #6 normalization: mean_stddev_norm(temp)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MEAN_STDDEV_NORMALIZATION);
    RETURN_IF_ERROR(
        graph->AddConsumer(node->id, gate_before_normalization->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, normalized_gate->id));
  }
  Value* reweighted_normalized_gate =
      CreateNewSimilarValue(graph, normalized_gate);
  {
    // #7 elementwise multiplication: normalization_weights .* #6
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    ElementwiseAttributes attr;
    Tensor<Linear, DataType::FLOAT32> norm_weights;
    RETURN_IF_ERROR(reader->ReadTensor(normalization_weight_id, &norm_weights));
    attr.param = std::move(norm_weights);
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, normalized_gate->id));
    RETURN_IF_ERROR(
        graph->SetProducer(node->id, reweighted_normalized_gate->id));
  }
  Value* gate = CreateNewSimilarValue(graph, reweighted_normalized_gate);
  {
    // #8 elementwise addition: #7 + bias
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::ADD);
    ElementwiseAttributes attr;
    Tensor<Linear, DataType::FLOAT32> bias;
    RETURN_IF_ERROR(reader->ReadTensor(bias_id, &bias));
    attr.param = std::move(bias);
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(
        graph->AddConsumer(node->id, reweighted_normalized_gate->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, gate->id));

    // #9 activation function
    RETURN_IF_ERROR(MaybeFuseActivation(activation, graph, node));
  }
  *gate_out = gate;
  return absl::OkStatus();
}

// Builds a subgraph for the LSTM cell state update.
// Returns a Value representing the updated cell state.
// High-level parameters:
// - clip: if > 0, clamp the resulting cell state to [-clip, +clip].
//
// Implements the following
// (.*: elementwise multiply, +: elementwise add):
//
// cell_state_new = clip(forget_gate .* cell_state + input_gate .* cell_gate);
//
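// clip(x) is realized below as max(min(x, cell_clip), -cell_clip), i.e. a
// MINIMUM node followed by a MAXIMUM node.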
absl::Status BuildCellStateUpdate(GraphFloat32* graph, ObjectReader* reader,
                                  Value* forget_gate, Value* input_gate,
                                  Value* cell_gate, float cell_clip,
                                  Value** cell_state_new) {
  Value* cell_state;
  RETURN_IF_ERROR(reader->ReadValue(
      tflite::ops::builtin::lstm::full::kCellStateTensor, &cell_state));
  Value* cell_state_contrib = CreateNewSimilarValue(graph, cell_gate);
  {
    // #1 elementwise multiplication: forget_gate .* cell_state
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, forget_gate->id));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, cell_state_contrib->id));
  }
  Value* cell_gate_contrib = CreateNewSimilarValue(graph, cell_gate);
  {
    // #2 elementwise multiplication: input_gate .* cell_gate
    // Note: with CIFG, input_gate equals 1 - forget_gate.
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, input_gate->id));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_gate->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, cell_gate_contrib->id));
  }
  Value* new_cell_state = CreateNewSimilarValue(graph, cell_gate);
  {
    // #3 elementwise addition: #1 + #2
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::ADD);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state_contrib->id));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_gate_contrib->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, new_cell_state->id));
  }

  if (cell_clip <= 0.0f) {
    *cell_state_new = new_cell_state;
    return absl::OkStatus();
  }

  Value* max_clipped_state = CreateNewSimilarValue(graph, new_cell_state);
  {
    // #4 elementwise minimum: min(#3, clip)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MINIMUM);
    ElementwiseAttributes attr;
    attr.param = cell_clip;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, new_cell_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, max_clipped_state->id));
  }
  Value* clipped_cell_state = CreateNewSimilarValue(graph, max_clipped_state);
  {
    // #5 elementwise maximum: max(#4, -clip)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MAXIMUM);
    ElementwiseAttributes attr;
    attr.param = -cell_clip;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, max_clipped_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, clipped_cell_state->id));
  }
  *cell_state_new = clipped_cell_state;
  return absl::OkStatus();
}

// Builds a subgraph for the LSTM output state update.
// Returns a Value representing the updated output state.
// High-level parameters:
// - Has projection (if true, provide projection_weights).
// - Has projection bias (only with projection).
// - clip: if > 0, clamp the projection output to [-clip, +clip].
// - Which activation function to use.
// Note that the updated output state does not depend on the old output state
// directly, only through the output gate.
//
// Implements the following
// (*: matrix multiply, .*: elementwise multiply):
//
// temp = output_gate .* activate(cell_state);
// if (projection):
//   output_state_new = clip(projection_weights * temp + projection_bias);
// else:
//   output_state_new = temp;
//
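// Only TANH and SIGMOID are supported for the activation here; any other
// value yields an InvalidArgumentError.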
absl::Status BuildOutputStateUpdate(GraphFloat32* graph, ObjectReader* reader,
                                    Value* output_state, Value* output_gate,
                                    Value* cell_state,
                                    TfLiteFusedActivation activation,
                                    bool has_projection, float proj_clip,
                                    Value** output_state_new) {
  Value* activated_state = CreateNewSimilarValue(graph, cell_state);
  {
    // #1 activation: activate(cell_state)
    Node* node = graph->NewNode();
    switch (activation) {
      case kTfLiteActTanh:
        node->operation.type = ToString(OperationType::TANH);
        break;
      case kTfLiteActSigmoid:
        node->operation.type = ToString(OperationType::SIGMOID);
        break;
      default:
        return absl::InvalidArgumentError(
            absl::StrCat("Unsupported activation: ", activation));
    }
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, activated_state->id));
  }

  Value* new_output_state = CreateNewSimilarValue(graph, cell_state);
  {
    // #2 elementwise multiplication: output_gate .* #1
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, activated_state->id));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, output_gate->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, new_output_state->id));
  }

  if (!has_projection) {
    *output_state_new = new_output_state;
    return absl::OkStatus();
  }

  Value* projected_output_state = CreateNewSimilarValue(graph, output_state);
  {
    // #3 matrix multiplication: projection_weights * #2 + projection_bias
    Node* node = graph->NewNode();
    FullyConnectedAttributes fc_attr;
    RETURN_IF_ERROR(SetFullyConnectedWeights(
        tflite::ops::builtin::lstm::full::kProjectionWeightsTensor, reader,
        &fc_attr));
    // The projection bias is optional.
    reader
        ->ReadTensor(tflite::ops::builtin::lstm::full::kProjectionBiasTensor,
                     &(fc_attr.bias))
        .IgnoreError();
    node->operation.attributes = std::move(fc_attr);
    node->operation.type = ToString(OperationType::FULLY_CONNECTED);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, new_output_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, projected_output_state->id));
  }

  if (proj_clip <= 0.0f) {
    *output_state_new = projected_output_state;
    return absl::OkStatus();
  }

  Value* max_clipped_state =
      CreateNewSimilarValue(graph, projected_output_state);
  {
    // #4 elementwise minimum: min(#3, clip)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MINIMUM);
    ElementwiseAttributes attr;
    attr.param = proj_clip;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, projected_output_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, max_clipped_state->id));
  }
  Value* clipped_output_state = CreateNewSimilarValue(graph, max_clipped_state);
  {
    // #5 elementwise maximum: max(#4, -clip)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MAXIMUM);
    ElementwiseAttributes attr;
    attr.param = -proj_clip;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, max_clipped_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, clipped_output_state->id));
  }
  *output_state_new = clipped_output_state;
  return absl::OkStatus();
}

}  // namespace

// Builds a subgraph for a single LSTM op.
// Returns a mapping from the variable tensors' indices to their updated
// Values.
//
// High-level parameters:
// - Has CIFG:
//     If false, calculate input_gate regularly.
//     If true, set input_gate to 1 - forget_gate.
// - Has peephole: see BuildLstmGate. Applies to all gates.
// - Has normalization: see BuildLstmGate. Applies to all gates.
// - Has projection, projection_bias, proj_clip: see BuildOutputStateUpdate.
// - Which activation to use:
//     Applies only to the cell gate and the output state update;
//     other gates always use sigmoid.
//
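// For reference, in the common case (no CIFG, peephole, normalization, or
// projection) the assembled subgraph computes the standard LSTM cell update
// (*: matrix multiply, .*: elementwise multiply):
//
//   forget_gate = sigmoid(W_f * input + U_f * output_state + b_f);
//   input_gate  = sigmoid(W_i * input + U_i * output_state + b_i);
//   cell_gate   = activate(W_c * input + U_c * output_state + b_c);
//   cell_state_new = forget_gate .* cell_state + input_gate .* cell_gate;
//   output_gate = sigmoid(W_o * input + U_o * output_state + b_o);
//   output_state_new = output_gate .* activate(cell_state_new);
//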
absl::Status ParseLSTMAttributes(
    const TfLiteNode* tflite_node, const TfLiteRegistration* registration,
    GraphFloat32* graph, ObjectReader* reader, const TfLiteLSTMParams* params,
    absl::flat_hash_map<int, ValueId>* new_variable_input_values) {
  const bool has_cifg = HasCifg(tflite_node);
  const bool has_peephole = HasPeephole(tflite_node);
  const bool has_normalization = HasNormalization(tflite_node);
  const bool has_projection = HasProjection(tflite_node);

  Value* old_cell_state;
  RETURN_IF_ERROR(reader->ReadValue(
      tflite::ops::builtin::lstm::full::kCellStateTensor, &old_cell_state));

  if (old_cell_state->tensor.shape.b != 1) {
    return absl::InvalidArgumentError(
        "Batched execution is not supported for LSTM");
  }

  Value* old_output_state;
  RETURN_IF_ERROR(reader->ReadValue(
      tflite::ops::builtin::lstm::full::kOutputStateTensor, &old_output_state));

  Value* forget_gate;
  RETURN_IF_ERROR(BuildLstmGate(
      graph, reader, old_output_state, old_cell_state,
      tflite::ops::builtin::lstm::full::kInputToForgetWeightsTensor,
      tflite::ops::builtin::lstm::full::kRecurrentToForgetWeightsTensor,
      tflite::ops::builtin::lstm::full::kCellToForgetWeightsTensor,
      tflite::ops::builtin::lstm::full::kForgetGateBiasTensor,
      tflite::ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor,
      kTfLiteActSigmoid, has_peephole, has_normalization, &forget_gate));

  Value* input_gate;
  if (has_cifg) {
    // When using CIFG, input_gate is computed as (1 - forget_gate).
    Node* node = graph->NewNode();
    input_gate = CreateNewSimilarValue(graph, forget_gate);

    node->operation.type = ToString(OperationType::SUB);
    ElementwiseAttributes attr;
    attr.param = 1.0f;
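    // runtime_tensor_is_second makes SUB treat the runtime input as the
    // second operand, so the node computes param - input, i.e.
    // 1.0f - forget_gate.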
    attr.runtime_tensor_is_second = true;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, forget_gate->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, input_gate->id));
  } else {
    RETURN_IF_ERROR(BuildLstmGate(
        graph, reader, old_output_state, old_cell_state,
        tflite::ops::builtin::lstm::full::kInputToInputWeightsTensor,
        tflite::ops::builtin::lstm::full::kRecurrentToInputWeightsTensor,
        tflite::ops::builtin::lstm::full::kCellToInputWeightsTensor,
        tflite::ops::builtin::lstm::full::kInputGateBiasTensor,
        tflite::ops::builtin::lstm::full::kInputLayerNormCoefficientsTensor,
        kTfLiteActSigmoid, has_peephole, has_normalization, &input_gate));
  }

  // The cell gate has no peephole connection to the cell state.
  Value* cell_gate;
  RETURN_IF_ERROR(BuildLstmGate(
      graph, reader, old_output_state, old_cell_state,
      tflite::ops::builtin::lstm::full::kInputToCellWeightsTensor,
      tflite::ops::builtin::lstm::full::kRecurrentToCellWeightsTensor,
      /*cell_weight_id=*/-1,
      tflite::ops::builtin::lstm::full::kCellGateBiasTensor,
      tflite::ops::builtin::lstm::full::kCellLayerNormCoefficientsTensor,
      params->activation, /*has_peephole=*/false, has_normalization,
      &cell_gate));

  Value* new_cell_state;
  RETURN_IF_ERROR(BuildCellStateUpdate(graph, reader, forget_gate, input_gate,
                                       cell_gate, params->cell_clip,
                                       &new_cell_state));

  Value* output_gate;
  RETURN_IF_ERROR(BuildLstmGate(
      graph, reader, old_output_state, new_cell_state,
      tflite::ops::builtin::lstm::full::kInputToOutputWeightsTensor,
      tflite::ops::builtin::lstm::full::kRecurrentToOutputWeightsTensor,
      tflite::ops::builtin::lstm::full::kCellToOutputWeightsTensor,
      tflite::ops::builtin::lstm::full::kOutputGateBiasTensor,
      tflite::ops::builtin::lstm::full::kOutputLayerNormCoefficientsTensor,
      kTfLiteActSigmoid, has_peephole, has_normalization, &output_gate));

  Value* new_output_state;
  RETURN_IF_ERROR(BuildOutputStateUpdate(graph, reader, old_output_state,
                                         output_gate, new_cell_state,
                                         params->activation, has_projection,
                                         params->proj_clip, &new_output_state));

  {
    // Copy the updated output state to the op's output tensor.
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::COPY);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, new_output_state->id));
    RETURN_IF_ERROR(reader->AddOutput(
        node, tflite::ops::builtin::lstm::full::kOutputTensor));
  }

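  // Record the Values that now hold the updated variable tensors so the
  // runtime can feed them back as state inputs on the next invocation.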
  new_variable_input_values->clear();
  new_variable_input_values->emplace(
      tflite::ops::builtin::lstm::full::kCellStateTensor, new_cell_state->id);
  new_variable_input_values->emplace(
      tflite::ops::builtin::lstm::full::kOutputStateTensor,
      new_output_state->id);
  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace tflite