/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/lstm_parser.h"

#include <optional>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/types/any.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/lstm_shared.h"
#include "tensorflow/lite/string_type.h"

namespace tflite {
namespace gpu {
namespace {

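// Creates a new graph Value that copies the shape, type, and quantization
// parameters of `old_value` but is not tied to any TFLite tensor (ref = -1).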
Value* CreateNewSimilarValue(GraphFloat32* graph, const Value* old_value) {
  Value* new_value = graph->NewValue();
  new_value->quant_params = old_value->quant_params;
  new_value->tensor.shape = old_value->tensor.shape;
  new_value->tensor.type = old_value->tensor.type;
  new_value->tensor.ref = -1;
  return new_value;
}

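// Reads a 2D (HW) weights tensor and repacks it into the OHWI layout expected
// by FULLY_CONNECTED: rows become output channels, columns input channels.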
absl::Status SetFullyConnectedWeights(int weights_tensor_id,
                                      ObjectReader* reader,
                                      FullyConnectedAttributes* attr) {
  Tensor<HW, DataType::FLOAT32> weights;
  RETURN_IF_ERROR(reader->ReadTensor(weights_tensor_id, &weights));
  attr->weights.data = std::move(weights.data);
  attr->weights.id = weights.id;
  attr->weights.shape.o = weights.shape.h;
  attr->weights.shape.h = 1;
  attr->weights.shape.w = 1;
  attr->weights.shape.i = weights.shape.w;
  return absl::OkStatus();
}

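// Returns true if the node has a (non-optional) input tensor at `index`.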
bool HasTensor(const TfLiteNode* node, const int index) {
  return (index < node->inputs->size) &&
         (node->inputs->data[index] != kTfLiteOptionalTensor);
}

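// CIFG (coupled input and forget gate) variants omit the input-to-input
// weights, so their absence is used to detect CIFG.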
bool HasCifg(const TfLiteNode* node) {
  return !HasTensor(
      node, tflite::ops::builtin::lstm::full::kInputToInputWeightsTensor);
}

bool HasPeephole(const TfLiteNode* node) {
  // Use forget weights to detect peephole instead of input weights as input
  // weights may be missing for cifg.
  return HasTensor(
      node, tflite::ops::builtin::lstm::full::kCellToForgetWeightsTensor);
}

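// Layer normalization is detected via the forget gate's norm coefficients,
// which are present in all layer-normalized variants.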
bool HasNormalization(const TfLiteNode* node) {
  return HasTensor(
      node,
      tflite::ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor);
}

bool HasProjection(const TfLiteNode* node) {
  return HasTensor(node,
                   tflite::ops::builtin::lstm::full::kProjectionWeightsTensor);
}

// Builds subgraph for a single LSTM gate.
// Returns a Value representing the gate's output.
// High-level parameters:
//   - Has normalization (if true: provide normalization weights).
//   - Has peephole connection (if true: provide peephole weights).
//   - Which activation function to use.
// Note: no support for aux input.
//
// Implements the following:
// (*: matrix multiply, .*: elementwise multiply, +: elementwise add):
//   temp = input_weights * input_tensor + recurrent_weights * output_state;
//   if (peephole):
//     temp += peephole_weights .* cell_state;
//   if (layer normalization):
//     gate = activate(normalization_weights .* mean_stddev_norm(temp) + bias);
//   else:
//     gate = activate(temp + bias);
//
absl::Status BuildLstmGate(GraphFloat32* graph, ObjectReader* reader,
                           Value* output_state, Value* cell_state,
                           int input_weight_id, int recurrent_weight_id,
                           int cell_weight_id, int bias_id,
                           int normalization_weight_id,
                           const TfLiteFusedActivation activation,
                           bool has_peephole, bool has_normalization,
                           Value** gate_out) {
  Value* input_times_weights = CreateNewSimilarValue(graph, cell_state);
  {
    // #1 matrix multiplication: input_weights * input_tensor
    // If there is no normalization, the bias is also added here.
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::FULLY_CONNECTED);
    FullyConnectedAttributes fc_attr;
    RETURN_IF_ERROR(
        SetFullyConnectedWeights(input_weight_id, reader, &fc_attr));
    if (!has_normalization) {
      RETURN_IF_ERROR(reader->ReadTensor(bias_id, &(fc_attr.bias)));
    }
    node->operation.attributes = std::move(fc_attr);
    RETURN_IF_ERROR(
        reader->AddInput(node, tflite::ops::builtin::lstm::full::kInputTensor));
    RETURN_IF_ERROR(graph->SetProducer(node->id, input_times_weights->id));
  }

  Value* output_state_times_weights = CreateNewSimilarValue(graph, cell_state);
  {
    // #2 matrix multiplication: recurrent_weights * output_state
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::FULLY_CONNECTED);
    FullyConnectedAttributes fc_attr;
    RETURN_IF_ERROR(
        SetFullyConnectedWeights(recurrent_weight_id, reader, &fc_attr));
    node->operation.attributes = std::move(fc_attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, output_state->id));
    RETURN_IF_ERROR(
        graph->SetProducer(node->id, output_state_times_weights->id));
  }

  Value* cell_state_times_weights;
  if (has_peephole) {
    // #3 elementwise multiplication: cell_weight .* cell_state
    cell_state_times_weights = CreateNewSimilarValue(graph, cell_state);
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    ElementwiseAttributes attr;
    Tensor<Linear, DataType::FLOAT32> weights;
    RETURN_IF_ERROR(reader->ReadTensor(cell_weight_id, &weights));
    attr.param = std::move(weights);
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, cell_state_times_weights->id));
  }

  Value* gate_before_normalization = CreateNewSimilarValue(graph, cell_state);
  Node* add_node = graph->NewNode();
  {
    // #4 elementwise addition: #1 + #2 + #3
    add_node->operation.type = ToString(OperationType::ADD);
    RETURN_IF_ERROR(graph->AddConsumer(add_node->id, input_times_weights->id));
    RETURN_IF_ERROR(
        graph->AddConsumer(add_node->id, output_state_times_weights->id));
    if (has_peephole) {
      RETURN_IF_ERROR(
          graph->AddConsumer(add_node->id, cell_state_times_weights->id));
    }
    RETURN_IF_ERROR(
        graph->SetProducer(add_node->id, gate_before_normalization->id));
  }

  if (!has_normalization) {
    // #5 Activation function: activate(temp + bias)
    // Bias is added in node #1.
    RETURN_IF_ERROR(MaybeFuseActivation(activation, graph, add_node));
    *gate_out = gate_before_normalization;
    return absl::OkStatus();
  }

  Value* normalized_gate =
      CreateNewSimilarValue(graph, gate_before_normalization);
  {
    // #6 Normalization: normalize(temp)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MEAN_STDDEV_NORMALIZATION);
    RETURN_IF_ERROR(
        graph->AddConsumer(node->id, gate_before_normalization->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, normalized_gate->id));
  }
  Value* reweighted_normalized_gate =
      CreateNewSimilarValue(graph, normalized_gate);
  {
    // #7 Elementwise multiplication: norm_weights .* #6
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    ElementwiseAttributes attr;
    Tensor<Linear, DataType::FLOAT32> norm_weights;
    RETURN_IF_ERROR(reader->ReadTensor(normalization_weight_id, &norm_weights));
    attr.param = std::move(norm_weights);
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, normalized_gate->id));
    RETURN_IF_ERROR(
        graph->SetProducer(node->id, reweighted_normalized_gate->id));
  }
  Value* gate = CreateNewSimilarValue(graph, reweighted_normalized_gate);
  {
    // #8 Elementwise add: #7 + bias
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::ADD);
    ElementwiseAttributes attr;
    Tensor<Linear, DataType::FLOAT32> bias;
    RETURN_IF_ERROR(reader->ReadTensor(bias_id, &bias));
    attr.param = std::move(bias);
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(
        graph->AddConsumer(node->id, reweighted_normalized_gate->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, gate->id));

    // #9: Activation function
    RETURN_IF_ERROR(MaybeFuseActivation(activation, graph, node));
  }
  *gate_out = gate;
  return absl::OkStatus();
}

// Builds subgraph for LSTM cell state update.
// Returns a Value representing the updated cell state.
// High-level parameters:
//  - clip: if > 0, clamp the resulting cell state to [-clip, +clip].
//
// Implements the following:
// (*: matrix multiply, .*: elementwise multiply, +: elementwise add):
//
//   cell_state_new = clip(forget_gate .* cell_state + input_gate .* cell_gate);
//
absl::Status BuildCellStateUpdate(GraphFloat32* graph, ObjectReader* reader,
                                  Value* forget_gate, Value* input_gate,
                                  Value* cell_gate, float cell_clip,
                                  Value** cell_state_new) {
  Value* cell_state;
  RETURN_IF_ERROR(reader->ReadValue(
      tflite::ops::builtin::lstm::full::kCellStateTensor, &cell_state));
  Value* cell_state_contrib = CreateNewSimilarValue(graph, cell_gate);
  {
    // #1 elementwise multiplication: forget_gate .* cell_state
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, forget_gate->id));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, cell_state_contrib->id));
  }
  Value* cell_gate_contrib = CreateNewSimilarValue(graph, cell_gate);
  {
    // #2 elementwise multiplication: input_gate .* cell_gate
    // Note: with CIFG, input_gate is equal to 1 - forget_gate.
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, input_gate->id));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_gate->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, cell_gate_contrib->id));
  }
  Value* new_cell_state = CreateNewSimilarValue(graph, cell_gate);
  {
    // #3 elementwise add: #1 + #2
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::ADD);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state_contrib->id));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_gate_contrib->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, new_cell_state->id));
  }

  if (cell_clip <= 0.0f) {
    *cell_state_new = new_cell_state;
    return absl::OkStatus();
  }

  Value* max_clipped_state = CreateNewSimilarValue(graph, new_cell_state);
  {
    // #4 elementwise minimum: min(#3, clip)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MINIMUM);
    ElementwiseAttributes attr;
    attr.param = cell_clip;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, new_cell_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, max_clipped_state->id));
  }
  Value* clipped_cell_state = CreateNewSimilarValue(graph, max_clipped_state);
  {
    // #5 elementwise maximum: max(#4, -clip)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MAXIMUM);
    ElementwiseAttributes attr;
    attr.param = -cell_clip;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, max_clipped_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, clipped_cell_state->id));
  }
  *cell_state_new = clipped_cell_state;
  return absl::OkStatus();
}

// Builds subgraph for LSTM output state update.
// Returns a Value representing the updated output state.
// High-level parameters:
//   - Has projection (if true, provide projection_weights).
//   - Has projection bias (only with projection).
//   - clip: if > 0, clamp the projection output to [-clip, +clip].
//   - Which activation function to use.
// Note the updated output state does not depend on the old output state
// directly, only through the output gate.
//
// Implements the following:
// (*: matrix multiply, .*: elementwise multiply, +: elementwise add):
//
//   temp = output_gate .* activate(cell_state);
//   if (projection):
//     output_state_new = clip(projection_weights * temp + projection_bias);
//   else:
//     output_state_new = temp;
//
absl::Status BuildOutputStateUpdate(GraphFloat32* graph, ObjectReader* reader,
                                    Value* output_state, Value* output_gate,
                                    Value* cell_state,
                                    TfLiteFusedActivation activation,
                                    bool has_projection, float proj_clip,
                                    Value** output_state_new) {
  Value* activated_state = CreateNewSimilarValue(graph, cell_state);
  {
    // #1 activation: activate(cell_state)
    Node* node = graph->NewNode();
    switch (activation) {
      case kTfLiteActTanh:
        node->operation.type = ToString(OperationType::TANH);
        break;
      case kTfLiteActSigmoid:
        node->operation.type = ToString(OperationType::SIGMOID);
        break;
      default:
        return absl::InvalidArgumentError(
            absl::StrCat("Unsupported activation: ", activation));
    }
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, activated_state->id));
  }

  Value* new_output_state = CreateNewSimilarValue(graph, cell_state);
  {
    // #2 elementwise multiplication: output_gate .* #1
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, activated_state->id));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, output_gate->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, new_output_state->id));
  }

  if (!has_projection) {
    *output_state_new = new_output_state;
    return absl::OkStatus();
  }

  Value* projected_output_state = CreateNewSimilarValue(graph, output_state);
  {
    // #3 matrix multiplication: projection_weights * #2 + projection_bias
    Node* node = graph->NewNode();
    FullyConnectedAttributes fc_attr;
    RETURN_IF_ERROR(SetFullyConnectedWeights(
        tflite::ops::builtin::lstm::full::kProjectionWeightsTensor, reader,
        &fc_attr));
    // Projection bias is optional
    reader
        ->ReadTensor(tflite::ops::builtin::lstm::full::kProjectionBiasTensor,
                     &(fc_attr.bias))
        .IgnoreError();
    node->operation.attributes = std::move(fc_attr);
    node->operation.type = ToString(OperationType::FULLY_CONNECTED);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, new_output_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, projected_output_state->id));
  }

  if (proj_clip <= 0.0f) {
    *output_state_new = projected_output_state;
    return absl::OkStatus();
  }

  Value* max_clipped_state =
      CreateNewSimilarValue(graph, projected_output_state);
  {
    // #4 elementwise minimum: min(#3, clip)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MINIMUM);
    ElementwiseAttributes attr;
    attr.param = proj_clip;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, projected_output_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, max_clipped_state->id));
  }
  Value* clipped_output_state = CreateNewSimilarValue(graph, max_clipped_state);
  {
    // #5 elementwise maximum: max(#4, -clip)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MAXIMUM);
    ElementwiseAttributes attr;
    attr.param = -proj_clip;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, max_clipped_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, clipped_output_state->id));
  }
  *output_state_new = clipped_output_state;
  return absl::OkStatus();
}

}  // namespace

// Builds subgraph for a single LSTM op.
// Returns a mapping for the used variable tensors' updated Values.
//
// High-level parameters:
//   - Has CIFG:
//       If false, compute input_gate as a regular gate.
//       If true, set input_gate to 1 - forget_gate.
//   - Has peephole: see BuildLstmGate. Applies to all gates.
//   - Has normalization: see BuildLstmGate. Applies to all gates.
//   - Has projection, projection_bias, proj_clip: see BuildOutputStateUpdate
//   - Which activation to use:
//       Applies to only cell gate and output state update.
//       Other gates always use Sigmoid.
//
absl::Status ParseLSTMAttributes(
    const TfLiteNode* tflite_node, const TfLiteRegistration* registration,
    GraphFloat32* graph, ObjectReader* reader, const TfLiteLSTMParams* params,
    absl::flat_hash_map<int, ValueId>* new_variable_input_values) {
  const bool has_cifg = HasCifg(tflite_node);
  const bool has_peephole = HasPeephole(tflite_node);
  const bool has_normalization = HasNormalization(tflite_node);
  const bool has_projection = HasProjection(tflite_node);

  Value* old_cell_state;
  RETURN_IF_ERROR(reader->ReadValue(
      tflite::ops::builtin::lstm::full::kCellStateTensor, &old_cell_state));

  if (old_cell_state->tensor.shape.b != 1) {
    return absl::InvalidArgumentError(
        "Batched execution is not supported for LSTM");
  }

  Value* old_output_state;
  RETURN_IF_ERROR(reader->ReadValue(
      tflite::ops::builtin::lstm::full::kOutputStateTensor, &old_output_state));

  Value* forget_gate;
  RETURN_IF_ERROR(BuildLstmGate(
      graph, reader, old_output_state, old_cell_state,
      tflite::ops::builtin::lstm::full::kInputToForgetWeightsTensor,
      tflite::ops::builtin::lstm::full::kRecurrentToForgetWeightsTensor,
      tflite::ops::builtin::lstm::full::kCellToForgetWeightsTensor,
      tflite::ops::builtin::lstm::full::kForgetGateBiasTensor,
      tflite::ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor,
      kTfLiteActSigmoid, has_peephole, has_normalization, &forget_gate));

  Value* input_gate;
  if (has_cifg) {
    // When using cifg, input_gate is computed as (1 - forget_gate).
    Node* node = graph->NewNode();
    input_gate = CreateNewSimilarValue(graph, forget_gate);

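    // Elementwise SUB with a scalar on the left: attr.param holds the
    // constant 1.0f, and runtime_tensor_is_second makes forget_gate the
    // right-hand operand, so the node computes 1.0f - forget_gate.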
    node->operation.type = ToString(OperationType::SUB);
    ElementwiseAttributes attr;
    attr.param = 1.0f;
    attr.runtime_tensor_is_second = true;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, forget_gate->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, input_gate->id));
  } else {
    RETURN_IF_ERROR(BuildLstmGate(
        graph, reader, old_output_state, old_cell_state,
        tflite::ops::builtin::lstm::full::kInputToInputWeightsTensor,
        tflite::ops::builtin::lstm::full::kRecurrentToInputWeightsTensor,
        tflite::ops::builtin::lstm::full::kCellToInputWeightsTensor,
        tflite::ops::builtin::lstm::full::kInputGateBiasTensor,
        tflite::ops::builtin::lstm::full::kInputLayerNormCoefficientsTensor,
        kTfLiteActSigmoid, has_peephole, has_normalization, &input_gate));
  }

  // Cell state will not have peephole connections to itself.
  Value* cell_gate;
  RETURN_IF_ERROR(BuildLstmGate(
      graph, reader, old_output_state, old_cell_state,
      tflite::ops::builtin::lstm::full::kInputToCellWeightsTensor,
      tflite::ops::builtin::lstm::full::kRecurrentToCellWeightsTensor,
      /*cell_weight_id=*/-1,
      tflite::ops::builtin::lstm::full::kCellGateBiasTensor,
      tflite::ops::builtin::lstm::full::kCellLayerNormCoefficientsTensor,
      params->activation, /*has_peephole=*/false, has_normalization,
      &cell_gate));

  Value* new_cell_state;
  RETURN_IF_ERROR(BuildCellStateUpdate(graph, reader, forget_gate, input_gate,
                                       cell_gate, params->cell_clip,
                                       &new_cell_state));

  Value* output_gate;
  RETURN_IF_ERROR(BuildLstmGate(
      graph, reader, old_output_state, new_cell_state,
      tflite::ops::builtin::lstm::full::kInputToOutputWeightsTensor,
      tflite::ops::builtin::lstm::full::kRecurrentToOutputWeightsTensor,
      tflite::ops::builtin::lstm::full::kCellToOutputWeightsTensor,
      tflite::ops::builtin::lstm::full::kOutputGateBiasTensor,
      tflite::ops::builtin::lstm::full::kOutputLayerNormCoefficientsTensor,
      kTfLiteActSigmoid, has_peephole, has_normalization, &output_gate));

  Value* new_output_state;
  RETURN_IF_ERROR(BuildOutputStateUpdate(graph, reader, old_output_state,
                                         output_gate, new_cell_state,
                                         params->activation, has_projection,
                                         params->proj_clip, &new_output_state));

  {
    // Copy updated output state to output.
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::COPY);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, new_output_state->id));
    RETURN_IF_ERROR(reader->AddOutput(
        node, tflite::ops::builtin::lstm::full::kOutputTensor));
  }

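  // Report the Values that now hold the updated variable tensors so the
  // caller can feed them back as the next invocation's state inputs.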
  new_variable_input_values->clear();
  new_variable_input_values->emplace(
      tflite::ops::builtin::lstm::full::kCellStateTensor, new_cell_state->id);
  new_variable_input_values->emplace(
      tflite::ops::builtin::lstm::full::kOutputStateTensor,
      new_output_state->id);
  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace tflite