1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "BidirectionalSequenceLSTM.h"
18 
19 #include "CpuExecutor.h"
20 #include "CpuOperationUtils.h"
21 #include "HalInterfaces.h"
22 #include "OperationsUtils.h"
23 
24 #include "Tracing.h"
25 
26 namespace android {
27 namespace nn {
28 
29 namespace {
30 
31 template <typename T>
GetBuffer(RunTimeOperandInfo * operand)32 inline T* GetBuffer(RunTimeOperandInfo* operand) {
33     return reinterpret_cast<T*>(operand->buffer);
34 }
35 
36 template <typename T>
GetBuffer(const RunTimeOperandInfo * operand)37 inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
38     return reinterpret_cast<const T*>(operand->buffer);
39 }
40 
41 template <typename T>
GetOptionalBuffer(const RunTimeOperandInfo * operand)42 inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
43     return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
44 }
45 
46 }  // anonymous namespace
47 
BidirectionalSequenceLSTM(const Operation & operation,std::vector<RunTimeOperandInfo> & operands)48 BidirectionalSequenceLSTM::BidirectionalSequenceLSTM(const Operation& operation,
49                                                      std::vector<RunTimeOperandInfo>& operands) {
50     input_ = GetInput(operation, operands, kInputTensor);
51 
52     fw_input_to_input_weights_ =
53             GetInput(operation, operands, kFwInputToInputWeightsTensor);  // optional
54     fw_input_to_forget_weights_ = GetInput(operation, operands, kFwInputToForgetWeightsTensor);
55     fw_input_to_cell_weights_ = GetInput(operation, operands, kFwInputToCellWeightsTensor);
56     fw_input_to_output_weights_ = GetInput(operation, operands, kFwInputToOutputWeightsTensor);
57 
58     fw_recurrent_to_input_weights_ =
59             GetInput(operation, operands, kFwRecurrentToInputWeightsTensor);  // optional
60     fw_recurrent_to_forget_weights_ =
61             GetInput(operation, operands, kFwRecurrentToForgetWeightsTensor);
62     fw_recurrent_to_cell_weights_ = GetInput(operation, operands, kFwRecurrentToCellWeightsTensor);
63     fw_recurrent_to_output_weights_ =
64             GetInput(operation, operands, kFwRecurrentToOutputWeightsTensor);
65 
66     fw_cell_to_input_weights_ =
67             GetInput(operation, operands, kFwCellToInputWeightsTensor);  // optional
68     fw_cell_to_forget_weights_ =
69             GetInput(operation, operands, kFwCellToForgetWeightsTensor);  // optional
70     fw_cell_to_output_weights_ =
71             GetInput(operation, operands, kFwCellToOutputWeightsTensor);  // optional
72 
73     fw_input_gate_bias_ = GetInput(operation, operands, kFwInputGateBiasTensor);
74     fw_forget_gate_bias_ = GetInput(operation, operands, kFwForgetGateBiasTensor);
75     fw_cell_bias_ = GetInput(operation, operands, kFwCellGateBiasTensor);
76     fw_output_gate_bias_ = GetInput(operation, operands, kFwOutputGateBiasTensor);
77 
78     fw_projection_weights_ = GetInput(operation, operands, kFwProjectionWeightsTensor);  // optional
79     fw_projection_bias_ = GetInput(operation, operands, kFwProjectionBiasTensor);        // optional
80 
81     fw_activation_state_ = GetInput(operation, operands, kFwInputActivationStateTensor);
82     fw_cell_state_ = GetInput(operation, operands, kFwInputCellStateTensor);
83 
84     bw_input_to_input_weights_ =
85             GetInput(operation, operands, kBwInputToInputWeightsTensor);  // optional
86     bw_input_to_forget_weights_ = GetInput(operation, operands, kBwInputToForgetWeightsTensor);
87     bw_input_to_cell_weights_ = GetInput(operation, operands, kBwInputToCellWeightsTensor);
88     bw_input_to_output_weights_ = GetInput(operation, operands, kBwInputToOutputWeightsTensor);
89 
90     bw_recurrent_to_input_weights_ =
91             GetInput(operation, operands, kBwRecurrentToInputWeightsTensor);  // optional
92     bw_recurrent_to_forget_weights_ =
93             GetInput(operation, operands, kBwRecurrentToForgetWeightsTensor);
94     bw_recurrent_to_cell_weights_ = GetInput(operation, operands, kBwRecurrentToCellWeightsTensor);
95     bw_recurrent_to_output_weights_ =
96             GetInput(operation, operands, kBwRecurrentToOutputWeightsTensor);
97 
98     bw_cell_to_input_weights_ =
99             GetInput(operation, operands, kBwCellToInputWeightsTensor);  // optional
100     bw_cell_to_forget_weights_ =
101             GetInput(operation, operands, kBwCellToForgetWeightsTensor);  // optional
102     bw_cell_to_output_weights_ =
103             GetInput(operation, operands, kBwCellToOutputWeightsTensor);  // optional
104 
105     bw_input_gate_bias_ = GetInput(operation, operands, kBwInputGateBiasTensor);
106     bw_forget_gate_bias_ = GetInput(operation, operands, kBwForgetGateBiasTensor);
107     bw_cell_bias_ = GetInput(operation, operands, kBwCellGateBiasTensor);
108     bw_output_gate_bias_ = GetInput(operation, operands, kBwOutputGateBiasTensor);
109 
110     bw_projection_weights_ = GetInput(operation, operands, kBwProjectionWeightsTensor);  // optional
111     bw_projection_bias_ = GetInput(operation, operands, kBwProjectionBiasTensor);        // optional
112 
113     bw_activation_state_ = GetInput(operation, operands, kBwInputActivationStateTensor);
114     bw_cell_state_ = GetInput(operation, operands, kBwInputCellStateTensor);
115 
116     aux_input_ = GetInput(operation, operands, kAuxInputTensor);
117     fw_aux_input_to_input_weights_ = GetInput(operation, operands, kFwAuxInputToInputWeightsTensor);
118     fw_aux_input_to_forget_weights_ =
119             GetInput(operation, operands, kFwAuxInputToForgetWeightsTensor);
120     fw_aux_input_to_cell_weights_ = GetInput(operation, operands, kFwAuxInputToCellWeightsTensor);
121     fw_aux_input_to_output_weights_ =
122             GetInput(operation, operands, kFwAuxInputToOutputWeightsTensor);
123     bw_aux_input_to_input_weights_ = GetInput(operation, operands, kBwAuxInputToInputWeightsTensor);
124     bw_aux_input_to_forget_weights_ =
125             GetInput(operation, operands, kBwAuxInputToForgetWeightsTensor);
126     bw_aux_input_to_cell_weights_ = GetInput(operation, operands, kBwAuxInputToCellWeightsTensor);
127     bw_aux_input_to_output_weights_ =
128             GetInput(operation, operands, kBwAuxInputToOutputWeightsTensor);
129 
130     fw_input_layer_norm_weights_ = GetInput(operation, operands, kFwInputLayerNormWeightsTensor);
131     fw_forget_layer_norm_weights_ = GetInput(operation, operands, kFwForgetLayerNormWeightsTensor);
132     fw_cell_layer_norm_weights_ = GetInput(operation, operands, kFwCellLayerNormWeightsTensor);
133     fw_output_layer_norm_weights_ = GetInput(operation, operands, kFwOutputLayerNormWeightsTensor);
134     bw_input_layer_norm_weights_ = GetInput(operation, operands, kBwInputLayerNormWeightsTensor);
135     bw_forget_layer_norm_weights_ = GetInput(operation, operands, kBwForgetLayerNormWeightsTensor);
136     bw_cell_layer_norm_weights_ = GetInput(operation, operands, kBwCellLayerNormWeightsTensor);
137     bw_output_layer_norm_weights_ = GetInput(operation, operands, kBwOutputLayerNormWeightsTensor);
138 
139     params_.activation = static_cast<TfLiteFusedActivation>(
140             getScalarData<int32_t>(*GetInput(operation, operands, kActivationParam)));
141     if (input_->type == OperandType::TENSOR_FLOAT32) {
142         params_.cell_clip = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
143         params_.proj_clip = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
144     } else {
145         params_.cell_clip = static_cast<float>(
146                 getScalarData<_Float16>(*GetInput(operation, operands, kCellClipParam)));
147         params_.proj_clip = static_cast<float>(
148                 getScalarData<_Float16>(*GetInput(operation, operands, kProjClipParam)));
149     }
150     params_.merge_outputs = getScalarData<bool>(*GetInput(operation, operands, kMergeOutputsParam));
151     params_.time_major = getScalarData<bool>(*GetInput(operation, operands, kTimeMajorParam));
152     params_.use_layer_norm = !IsNullInput(fw_input_layer_norm_weights_);
153 
154     fw_output_ = GetOutput(operation, operands, kFwOutputTensor);
155     if (!params_.merge_outputs) {
156         bw_output_ = GetOutput(operation, operands, kBwOutputTensor);
157     }
158 }
159 
Prepare(const Operation & operation,std::vector<RunTimeOperandInfo> & operands,Shape * fwOutputShape,Shape * bwOutputShape)160 bool BidirectionalSequenceLSTM::Prepare(const Operation& operation,
161                                         std::vector<RunTimeOperandInfo>& operands,
162                                         Shape* fwOutputShape, Shape* bwOutputShape) {
163     // Inferring batch size, number of outputs and number of cells from the
164     // input tensors.
165     NN_CHECK(NumDimensions(input_) == 3);
166     const uint32_t max_time = SizeOfDimension(input_, params_.time_major ? 0 : 1);
167     const uint32_t n_batch = SizeOfDimension(input_, params_.time_major ? 1 : 0);
168     const uint32_t n_input = SizeOfDimension(input_, 2);
169 
170     const uint32_t n_fw_cell = SizeOfDimension(fw_input_to_output_weights_, 0);
171     NN_CHECK_EQ(NumDimensions(fw_input_to_output_weights_), 2);
172     NN_CHECK_EQ(SizeOfDimension(fw_input_to_output_weights_, 1), n_input);
173 
174     NN_CHECK_EQ(NumDimensions(fw_recurrent_to_output_weights_), 2);
175     NN_CHECK_EQ(SizeOfDimension(fw_recurrent_to_output_weights_, 0), n_fw_cell);
176     const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);
177 
178     // Check that input tensor dimensions matches with each other.
179     if (!LSTMCell::CheckInputTensorDimensions(
180                 input_, fw_input_to_input_weights_, fw_input_to_forget_weights_,
181                 fw_input_to_cell_weights_, fw_input_to_output_weights_,
182                 fw_recurrent_to_input_weights_, fw_recurrent_to_forget_weights_,
183                 fw_recurrent_to_cell_weights_, fw_recurrent_to_output_weights_,
184                 fw_cell_to_input_weights_, fw_cell_to_forget_weights_, fw_cell_to_output_weights_,
185                 fw_input_gate_bias_, fw_forget_gate_bias_, fw_cell_bias_, fw_output_gate_bias_,
186                 fw_projection_weights_, fw_projection_bias_, fw_input_layer_norm_weights_,
187                 fw_forget_layer_norm_weights_, fw_cell_layer_norm_weights_,
188                 fw_output_layer_norm_weights_, n_input, n_fw_output, n_fw_cell, &params_)) {
189         return false;
190     }
191 
192     const bool aux_inputs_all_or_none =
193             (!IsNullInput(aux_input_) && !IsNullInput(fw_aux_input_to_cell_weights_) &&
194              !IsNullInput(fw_aux_input_to_forget_weights_) &&
195              !IsNullInput(fw_aux_input_to_output_weights_) &&
196              !IsNullInput(bw_aux_input_to_cell_weights_) &&
197              !IsNullInput(bw_aux_input_to_forget_weights_) &&
198              !IsNullInput(bw_aux_input_to_output_weights_)) ||
199             (IsNullInput(fw_aux_input_to_cell_weights_) &&
200              IsNullInput(fw_aux_input_to_forget_weights_) &&
201              IsNullInput(fw_aux_input_to_output_weights_) &&
202              IsNullInput(bw_aux_input_to_cell_weights_) &&
203              IsNullInput(bw_aux_input_to_forget_weights_) &&
204              IsNullInput(bw_aux_input_to_output_weights_));
205     NN_CHECK(aux_inputs_all_or_none);
206     if (!IsNullInput(aux_input_)) {
207         // Check that aux_input has the same dimensions (except last) as the input.
208         NN_CHECK_EQ(aux_input_->shape().dimensions[0], input_->shape().dimensions[0]);
209         NN_CHECK_EQ(aux_input_->shape().dimensions[1], input_->shape().dimensions[1]);
210     }
211 
212     const uint32_t n_bw_cell = SizeOfDimension(bw_input_to_output_weights_, 0);
213     NN_CHECK_EQ(NumDimensions(bw_input_to_output_weights_), 2);
214     NN_CHECK_EQ(SizeOfDimension(bw_input_to_output_weights_, 1), n_input);
215 
216     NN_CHECK_EQ(NumDimensions(bw_recurrent_to_output_weights_), 2);
217     NN_CHECK_EQ(SizeOfDimension(bw_recurrent_to_output_weights_, 0), n_bw_cell);
218     const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);
219 
220     const Shape& inputShape = input_->shape();
221     fwOutputShape->type = inputShape.type;
222     fwOutputShape->offset = inputShape.offset;
223     fwOutputShape->scale = inputShape.scale;
224     fwOutputShape->dimensions.resize(3);
225     fwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
226     fwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
227     fwOutputShape->dimensions[2] = params_.merge_outputs ? n_fw_output + n_bw_output : n_fw_output;
228 
229     // Check that input tensor dimensions matches with each other.
230     if (!LSTMCell::CheckInputTensorDimensions(
231                 input_, bw_input_to_input_weights_, bw_input_to_forget_weights_,
232                 bw_input_to_cell_weights_, bw_input_to_output_weights_,
233                 bw_recurrent_to_input_weights_, bw_recurrent_to_forget_weights_,
234                 bw_recurrent_to_cell_weights_, bw_recurrent_to_output_weights_,
235                 bw_cell_to_input_weights_, bw_cell_to_forget_weights_, bw_cell_to_output_weights_,
236                 bw_input_gate_bias_, bw_forget_gate_bias_, bw_cell_bias_, bw_output_gate_bias_,
237                 bw_projection_weights_, bw_projection_bias_, bw_input_layer_norm_weights_,
238                 bw_forget_layer_norm_weights_, bw_cell_layer_norm_weights_,
239                 bw_output_layer_norm_weights_, n_input, n_bw_output, n_bw_cell, &params_)) {
240         return false;
241     }
242 
243     if (!params_.merge_outputs) {
244         bwOutputShape->type = inputShape.type;
245         bwOutputShape->offset = inputShape.offset;
246         bwOutputShape->scale = inputShape.scale;
247         bwOutputShape->dimensions.resize(3);
248         bwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
249         bwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
250         bwOutputShape->dimensions[2] = n_bw_output;
251     }
252 
253     if (params_.use_cifg) {
254         fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 3};
255         bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 3};
256     } else {
257         fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 4};
258         bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 4};
259     }
260     fw_scratch_shape_.type = bw_scratch_shape_.type = inputShape.type;
261     fw_scratch_shape_.offset = bw_scratch_shape_.offset = inputShape.offset;
262     fw_scratch_shape_.scale = bw_scratch_shape_.scale = inputShape.scale;
263 
264     return true;
265 }
266 
Eval()267 bool BidirectionalSequenceLSTM::Eval() {
268     const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);
269     const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);
270     std::vector<uint32_t> fw_output_dims = input_->shape().dimensions;
271     fw_output_dims[2] = n_fw_output;
272     std::vector<uint32_t> bw_output_dims = fw_output_dims;
273     bw_output_dims[2] = n_bw_output;
274     const uint32_t n_fw_output_elements = fw_output_dims[0] * fw_output_dims[1] * fw_output_dims[2];
275     const uint32_t n_output_elements =
276             fw_output_dims[0] * fw_output_dims[1] * (fw_output_dims[2] + bw_output_dims[2]);
277 
278     switch (input_->type) {
279         case OperandType::TENSOR_FLOAT32: {
280             std::vector<float> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
281             const bool kForwardSequence = true;
282             LSTMCell::LSTMEvalFloat32(
283                     params_, GetBuffer<const float>(input_), input_->shape(),
284                     GetBuffer<const float>(fw_input_to_input_weights_),
285                     GetBuffer<const float>(fw_input_to_forget_weights_),
286                     GetBuffer<const float>(fw_input_to_cell_weights_),
287                     GetBuffer<const float>(fw_input_to_output_weights_),
288                     fw_input_to_output_weights_->shape(),
289                     GetBuffer<const float>(fw_recurrent_to_input_weights_),
290                     GetBuffer<const float>(fw_recurrent_to_forget_weights_),
291                     GetBuffer<const float>(fw_recurrent_to_cell_weights_),
292                     GetBuffer<const float>(fw_recurrent_to_output_weights_),
293                     fw_recurrent_to_output_weights_->shape(),
294                     GetBuffer<const float>(fw_cell_to_input_weights_),
295                     GetBuffer<const float>(fw_cell_to_forget_weights_),
296                     GetBuffer<const float>(fw_cell_to_output_weights_),
297                     GetOptionalBuffer<const float>(aux_input_),
298                     GetOptionalBuffer<const float>(fw_aux_input_to_input_weights_),
299                     GetOptionalBuffer<const float>(fw_aux_input_to_forget_weights_),
300                     GetOptionalBuffer<const float>(fw_aux_input_to_cell_weights_),
301                     GetOptionalBuffer<const float>(fw_aux_input_to_output_weights_),
302                     GetBuffer<const float>(fw_input_gate_bias_),
303                     GetBuffer<const float>(fw_forget_gate_bias_),
304                     GetBuffer<const float>(fw_cell_bias_),
305                     GetBuffer<const float>(fw_output_gate_bias_),
306                     GetBuffer<const float>(fw_projection_weights_),
307                     GetBuffer<const float>(fw_projection_bias_),
308                     GetBuffer<const float>(fw_activation_state_),
309                     GetBuffer<const float>(fw_cell_state_),
310                     GetOptionalBuffer<const float>(fw_input_layer_norm_weights_),
311                     GetOptionalBuffer<const float>(fw_forget_layer_norm_weights_),
312                     GetOptionalBuffer<const float>(fw_cell_layer_norm_weights_),
313                     GetOptionalBuffer<const float>(fw_output_layer_norm_weights_),
314                     GetBuffer<float>(fw_activation_state_), GetBuffer<float>(fw_cell_state_),
315                     GetBuffer<float>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
316                     kForwardSequence);
317 
318             std::vector<float> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
319             const bool kBackwardSequence = false;
320             LSTMCell::LSTMEvalFloat32(
321                     params_, GetBuffer<const float>(input_), input_->shape(),
322                     GetBuffer<const float>(bw_input_to_input_weights_),
323                     GetBuffer<const float>(bw_input_to_forget_weights_),
324                     GetBuffer<const float>(bw_input_to_cell_weights_),
325                     GetBuffer<const float>(bw_input_to_output_weights_),
326                     bw_input_to_output_weights_->shape(),
327                     GetBuffer<const float>(bw_recurrent_to_input_weights_),
328                     GetBuffer<const float>(bw_recurrent_to_forget_weights_),
329                     GetBuffer<const float>(bw_recurrent_to_cell_weights_),
330                     GetBuffer<const float>(bw_recurrent_to_output_weights_),
331                     bw_recurrent_to_output_weights_->shape(),
332                     GetBuffer<const float>(bw_cell_to_input_weights_),
333                     GetBuffer<const float>(bw_cell_to_forget_weights_),
334                     GetBuffer<const float>(bw_cell_to_output_weights_),
335                     GetOptionalBuffer<const float>(aux_input_),
336                     GetOptionalBuffer<const float>(bw_aux_input_to_input_weights_),
337                     GetOptionalBuffer<const float>(bw_aux_input_to_forget_weights_),
338                     GetOptionalBuffer<const float>(bw_aux_input_to_cell_weights_),
339                     GetOptionalBuffer<const float>(bw_aux_input_to_output_weights_),
340                     GetBuffer<const float>(bw_input_gate_bias_),
341                     GetBuffer<const float>(bw_forget_gate_bias_),
342                     GetBuffer<const float>(bw_cell_bias_),
343                     GetBuffer<const float>(bw_output_gate_bias_),
344                     GetBuffer<const float>(bw_projection_weights_),
345                     GetBuffer<const float>(bw_projection_bias_),
346                     GetBuffer<const float>(bw_activation_state_),
347                     GetBuffer<const float>(bw_cell_state_),
348                     GetOptionalBuffer<const float>(bw_input_layer_norm_weights_),
349                     GetOptionalBuffer<const float>(bw_forget_layer_norm_weights_),
350                     GetOptionalBuffer<const float>(bw_cell_layer_norm_weights_),
351                     GetOptionalBuffer<const float>(bw_output_layer_norm_weights_),
352                     GetBuffer<float>(bw_activation_state_), GetBuffer<float>(bw_cell_state_),
353                     params_.merge_outputs ? GetBuffer<float>(fw_output_) + n_fw_output_elements
354                                           : GetBuffer<float>(bw_output_),
355                     bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
356             if (params_.merge_outputs) {
357                 std::vector<float> temp(n_output_elements);
358                 mergeThirdDimension(GetBuffer<float>(fw_output_), fw_output_dims,
359                                     GetBuffer<float>(fw_output_) + n_fw_output_elements,
360                                     bw_output_dims, temp.data());
361                 std::copy(temp.data(), temp.data() + n_output_elements,
362                           GetBuffer<float>(fw_output_));
363             }
364         } break;
365         case OperandType::TENSOR_FLOAT16: {
366             std::vector<_Float16> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
367             const bool kForwardSequence = true;
368             LSTMCell::LSTMEvalFloat16(
369                     params_, GetBuffer<const _Float16>(input_), input_->shape(),
370                     GetOptionalBuffer<const _Float16>(fw_input_to_input_weights_),
371                     GetBuffer<const _Float16>(fw_input_to_forget_weights_),
372                     GetBuffer<const _Float16>(fw_input_to_cell_weights_),
373                     GetBuffer<const _Float16>(fw_input_to_output_weights_),
374                     fw_input_to_output_weights_->shape(),
375                     GetOptionalBuffer<const _Float16>(fw_recurrent_to_input_weights_),
376                     GetBuffer<const _Float16>(fw_recurrent_to_forget_weights_),
377                     GetBuffer<const _Float16>(fw_recurrent_to_cell_weights_),
378                     GetBuffer<const _Float16>(fw_recurrent_to_output_weights_),
379                     fw_recurrent_to_output_weights_->shape(),
380                     GetOptionalBuffer<const _Float16>(fw_cell_to_input_weights_),
381                     GetOptionalBuffer<const _Float16>(fw_cell_to_forget_weights_),
382                     GetOptionalBuffer<const _Float16>(fw_cell_to_output_weights_),
383                     GetOptionalBuffer<const _Float16>(aux_input_),
384                     GetOptionalBuffer<const _Float16>(fw_aux_input_to_input_weights_),
385                     GetOptionalBuffer<const _Float16>(fw_aux_input_to_forget_weights_),
386                     GetOptionalBuffer<const _Float16>(fw_aux_input_to_cell_weights_),
387                     GetOptionalBuffer<const _Float16>(fw_aux_input_to_output_weights_),
388                     GetOptionalBuffer<const _Float16>(fw_input_gate_bias_),
389                     GetBuffer<const _Float16>(fw_forget_gate_bias_),
390                     GetBuffer<const _Float16>(fw_cell_bias_),
391                     GetBuffer<const _Float16>(fw_output_gate_bias_),
392                     GetOptionalBuffer<const _Float16>(fw_projection_weights_),
393                     GetOptionalBuffer<const _Float16>(fw_projection_bias_),
394                     GetBuffer<const _Float16>(fw_activation_state_),
395                     GetBuffer<const _Float16>(fw_cell_state_),
396                     GetOptionalBuffer<const _Float16>(fw_input_layer_norm_weights_),
397                     GetOptionalBuffer<const _Float16>(fw_forget_layer_norm_weights_),
398                     GetOptionalBuffer<const _Float16>(fw_cell_layer_norm_weights_),
399                     GetOptionalBuffer<const _Float16>(fw_output_layer_norm_weights_),
400                     GetBuffer<_Float16>(fw_activation_state_), GetBuffer<_Float16>(fw_cell_state_),
401                     GetBuffer<_Float16>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
402                     kForwardSequence);
403 
404             std::vector<_Float16> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
405             const bool kBackwardSequence = false;
406             LSTMCell::LSTMEvalFloat16(
407                     params_, GetBuffer<const _Float16>(input_), input_->shape(),
408                     GetOptionalBuffer<const _Float16>(bw_input_to_input_weights_),
409                     GetBuffer<const _Float16>(bw_input_to_forget_weights_),
410                     GetBuffer<const _Float16>(bw_input_to_cell_weights_),
411                     GetBuffer<const _Float16>(bw_input_to_output_weights_),
412                     bw_input_to_output_weights_->shape(),
413                     GetOptionalBuffer<const _Float16>(bw_recurrent_to_input_weights_),
414                     GetBuffer<const _Float16>(bw_recurrent_to_forget_weights_),
415                     GetBuffer<const _Float16>(bw_recurrent_to_cell_weights_),
416                     GetBuffer<const _Float16>(bw_recurrent_to_output_weights_),
417                     bw_recurrent_to_output_weights_->shape(),
418                     GetOptionalBuffer<const _Float16>(bw_cell_to_input_weights_),
419                     GetOptionalBuffer<const _Float16>(bw_cell_to_forget_weights_),
420                     GetOptionalBuffer<const _Float16>(bw_cell_to_output_weights_),
421                     GetOptionalBuffer<const _Float16>(aux_input_),
422                     GetOptionalBuffer<const _Float16>(bw_aux_input_to_input_weights_),
423                     GetOptionalBuffer<const _Float16>(bw_aux_input_to_forget_weights_),
424                     GetOptionalBuffer<const _Float16>(bw_aux_input_to_cell_weights_),
425                     GetOptionalBuffer<const _Float16>(bw_aux_input_to_output_weights_),
426                     GetOptionalBuffer<const _Float16>(bw_input_gate_bias_),
427                     GetBuffer<const _Float16>(bw_forget_gate_bias_),
428                     GetBuffer<const _Float16>(bw_cell_bias_),
429                     GetBuffer<const _Float16>(bw_output_gate_bias_),
430                     GetOptionalBuffer<const _Float16>(bw_projection_weights_),
431                     GetOptionalBuffer<const _Float16>(bw_projection_bias_),
432                     GetBuffer<const _Float16>(bw_activation_state_),
433                     GetBuffer<const _Float16>(bw_cell_state_),
434                     GetOptionalBuffer<const _Float16>(bw_input_layer_norm_weights_),
435                     GetOptionalBuffer<const _Float16>(bw_forget_layer_norm_weights_),
436                     GetOptionalBuffer<const _Float16>(bw_cell_layer_norm_weights_),
437                     GetOptionalBuffer<const _Float16>(bw_output_layer_norm_weights_),
438                     GetBuffer<_Float16>(bw_activation_state_), GetBuffer<_Float16>(bw_cell_state_),
439                     params_.merge_outputs ? GetBuffer<_Float16>(fw_output_) + n_fw_output_elements
440                                           : GetBuffer<_Float16>(bw_output_),
441                     bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
442             if (params_.merge_outputs) {
443                 std::vector<_Float16> temp(n_output_elements);
444                 mergeThirdDimension(GetBuffer<_Float16>(fw_output_), fw_output_dims,
445                                     GetBuffer<_Float16>(fw_output_) + n_fw_output_elements,
446                                     bw_output_dims, temp.data());
447                 std::copy(temp.data(), temp.data() + n_output_elements,
448                           GetBuffer<_Float16>(fw_output_));
449             }
450         } break;
451         default: {
452             LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
453             return false;
454         }
455     }
456     return true;
457 }
458 
459 }  // namespace nn
460 }  // namespace android
461