1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "BidirectionalSequenceLSTM.h"
20 
21 #include <algorithm>
22 #include <vector>
23 
24 #include "CpuExecutor.h"
25 #include "CpuOperationUtils.h"
26 #include "HalInterfaces.h"
27 #include "OperationsUtils.h"
28 #include "Tracing.h"
29 
30 namespace android {
31 namespace nn {
32 
33 namespace {
34 
35 using namespace hal;
36 
37 template <typename T>
GetBuffer(RunTimeOperandInfo * operand)38 inline T* GetBuffer(RunTimeOperandInfo* operand) {
39     return reinterpret_cast<T*>(operand->buffer);
40 }
41 
42 template <typename T>
GetBuffer(const RunTimeOperandInfo * operand)43 inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
44     return reinterpret_cast<const T*>(operand->buffer);
45 }
46 
47 template <typename T>
GetOptionalBuffer(const RunTimeOperandInfo * operand)48 inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
49     return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
50 }
51 
// How the auxiliary input (if any) is connected to the backward network.
// See getLinkingMode() below for how each mode is selected.
enum class LinkingMode {
    NO_LINKING,        // No auxiliary input and no auxiliary weights.
    PARALLEL_LINKING,  // Auxiliary input acts as the backward net's regular input.
    CROSS_LINKING,     // Auxiliary input is multiplied by auxiliary weights.
};
57 
getLinkingMode(bool hasAuxInput,bool hasAuxWeights,LinkingMode * linkingMode)58 bool getLinkingMode(bool hasAuxInput, bool hasAuxWeights, LinkingMode* linkingMode) {
59     // Three possible configurations for three possible linking modes:
60     // 1) NO_LINKING -- no auxiliary tensors at all
61     // 2) PARALLEL_LINKING -- auxiliary input is provided and used as a regular
62     //    input to the backward network, so the auxiliary weights are omitted.
63     // 3) CROSS_LINKING -- auxiliary input is provided and multiplied by
64     //    auxiliary weights.
65     if (!hasAuxInput && !hasAuxWeights) {
66         *linkingMode = LinkingMode::NO_LINKING;
67     } else if (hasAuxInput && !hasAuxWeights) {
68         *linkingMode = LinkingMode::PARALLEL_LINKING;
69     } else if (hasAuxInput && hasAuxWeights) {
70         *linkingMode = LinkingMode::CROSS_LINKING;
71     } else {
72         NN_RET_CHECK_FAIL()
73                 << "Unsupported auxiliary tensors configuration for BIDIRECTIONAL_SEQUENCE_RNN.";
74     }
75 
76     return true;
77 }
78 
79 }  // anonymous namespace
80 
BidirectionalSequenceLSTM::BidirectionalSequenceLSTM(const Operation& operation,
                                                     RunTimeOperandInfo* operands) {
    // Caches pointers to every operand (weights, biases, states, outputs) and
    // decodes the scalar parameters into params_. No validation is performed
    // here; presence and shape checks are deferred to Prepare().
    input_ = GetInput(operation, operands, kInputTensor);

    // Forward input-to-gate weights (input gate optional, e.g. with CIFG).
    fw_input_to_input_weights_ =
            GetInput(operation, operands, kFwInputToInputWeightsTensor);  // optional
    fw_input_to_forget_weights_ = GetInput(operation, operands, kFwInputToForgetWeightsTensor);
    fw_input_to_cell_weights_ = GetInput(operation, operands, kFwInputToCellWeightsTensor);
    fw_input_to_output_weights_ = GetInput(operation, operands, kFwInputToOutputWeightsTensor);

    // Forward recurrent (output-to-gate) weights.
    fw_recurrent_to_input_weights_ =
            GetInput(operation, operands, kFwRecurrentToInputWeightsTensor);  // optional
    fw_recurrent_to_forget_weights_ =
            GetInput(operation, operands, kFwRecurrentToForgetWeightsTensor);
    fw_recurrent_to_cell_weights_ = GetInput(operation, operands, kFwRecurrentToCellWeightsTensor);
    fw_recurrent_to_output_weights_ =
            GetInput(operation, operands, kFwRecurrentToOutputWeightsTensor);

    // Forward peephole weights (all optional).
    fw_cell_to_input_weights_ =
            GetInput(operation, operands, kFwCellToInputWeightsTensor);  // optional
    fw_cell_to_forget_weights_ =
            GetInput(operation, operands, kFwCellToForgetWeightsTensor);  // optional
    fw_cell_to_output_weights_ =
            GetInput(operation, operands, kFwCellToOutputWeightsTensor);  // optional

    // Forward gate biases.
    fw_input_gate_bias_ = GetInput(operation, operands, kFwInputGateBiasTensor);
    fw_forget_gate_bias_ = GetInput(operation, operands, kFwForgetGateBiasTensor);
    fw_cell_bias_ = GetInput(operation, operands, kFwCellGateBiasTensor);
    fw_output_gate_bias_ = GetInput(operation, operands, kFwOutputGateBiasTensor);

    fw_projection_weights_ = GetInput(operation, operands, kFwProjectionWeightsTensor);  // optional
    fw_projection_bias_ = GetInput(operation, operands, kFwProjectionBiasTensor);        // optional

    // Forward initial (input) states.
    fw_activation_state_ = GetInput(operation, operands, kFwInputActivationStateTensor);
    fw_cell_state_ = GetInput(operation, operands, kFwInputCellStateTensor);

    // Backward counterparts of all of the tensors above.
    bw_input_to_input_weights_ =
            GetInput(operation, operands, kBwInputToInputWeightsTensor);  // optional
    bw_input_to_forget_weights_ = GetInput(operation, operands, kBwInputToForgetWeightsTensor);
    bw_input_to_cell_weights_ = GetInput(operation, operands, kBwInputToCellWeightsTensor);
    bw_input_to_output_weights_ = GetInput(operation, operands, kBwInputToOutputWeightsTensor);

    bw_recurrent_to_input_weights_ =
            GetInput(operation, operands, kBwRecurrentToInputWeightsTensor);  // optional
    bw_recurrent_to_forget_weights_ =
            GetInput(operation, operands, kBwRecurrentToForgetWeightsTensor);
    bw_recurrent_to_cell_weights_ = GetInput(operation, operands, kBwRecurrentToCellWeightsTensor);
    bw_recurrent_to_output_weights_ =
            GetInput(operation, operands, kBwRecurrentToOutputWeightsTensor);

    bw_cell_to_input_weights_ =
            GetInput(operation, operands, kBwCellToInputWeightsTensor);  // optional
    bw_cell_to_forget_weights_ =
            GetInput(operation, operands, kBwCellToForgetWeightsTensor);  // optional
    bw_cell_to_output_weights_ =
            GetInput(operation, operands, kBwCellToOutputWeightsTensor);  // optional

    bw_input_gate_bias_ = GetInput(operation, operands, kBwInputGateBiasTensor);
    bw_forget_gate_bias_ = GetInput(operation, operands, kBwForgetGateBiasTensor);
    bw_cell_bias_ = GetInput(operation, operands, kBwCellGateBiasTensor);
    bw_output_gate_bias_ = GetInput(operation, operands, kBwOutputGateBiasTensor);

    bw_projection_weights_ = GetInput(operation, operands, kBwProjectionWeightsTensor);  // optional
    bw_projection_bias_ = GetInput(operation, operands, kBwProjectionBiasTensor);        // optional

    bw_activation_state_ = GetInput(operation, operands, kBwInputActivationStateTensor);
    bw_cell_state_ = GetInput(operation, operands, kBwInputCellStateTensor);

    // Auxiliary input and the forward/backward auxiliary weights that control
    // the linking mode (see getLinkingMode()).
    aux_input_ = GetInput(operation, operands, kAuxInputTensor);
    fw_aux_input_to_input_weights_ = GetInput(operation, operands, kFwAuxInputToInputWeightsTensor);
    fw_aux_input_to_forget_weights_ =
            GetInput(operation, operands, kFwAuxInputToForgetWeightsTensor);
    fw_aux_input_to_cell_weights_ = GetInput(operation, operands, kFwAuxInputToCellWeightsTensor);
    fw_aux_input_to_output_weights_ =
            GetInput(operation, operands, kFwAuxInputToOutputWeightsTensor);
    bw_aux_input_to_input_weights_ = GetInput(operation, operands, kBwAuxInputToInputWeightsTensor);
    bw_aux_input_to_forget_weights_ =
            GetInput(operation, operands, kBwAuxInputToForgetWeightsTensor);
    bw_aux_input_to_cell_weights_ = GetInput(operation, operands, kBwAuxInputToCellWeightsTensor);
    bw_aux_input_to_output_weights_ =
            GetInput(operation, operands, kBwAuxInputToOutputWeightsTensor);

    // Layer-normalization weights, one tensor per gate and direction.
    fw_input_layer_norm_weights_ = GetInput(operation, operands, kFwInputLayerNormWeightsTensor);
    fw_forget_layer_norm_weights_ = GetInput(operation, operands, kFwForgetLayerNormWeightsTensor);
    fw_cell_layer_norm_weights_ = GetInput(operation, operands, kFwCellLayerNormWeightsTensor);
    fw_output_layer_norm_weights_ = GetInput(operation, operands, kFwOutputLayerNormWeightsTensor);
    bw_input_layer_norm_weights_ = GetInput(operation, operands, kBwInputLayerNormWeightsTensor);
    bw_forget_layer_norm_weights_ = GetInput(operation, operands, kBwForgetLayerNormWeightsTensor);
    bw_cell_layer_norm_weights_ = GetInput(operation, operands, kBwCellLayerNormWeightsTensor);
    bw_output_layer_norm_weights_ = GetInput(operation, operands, kBwOutputLayerNormWeightsTensor);

    // The fused activation is an int32 scalar; default to "none" when the
    // operand carries no value.
    const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
    params_.activation = static_cast<TfLiteFusedActivation>(getScalarDataWithDefault<int32_t>(
            activationOperand, TfLiteFusedActivation::kTfLiteActNone));
    // Clipping thresholds are stored in the model's precision: float32 for
    // TENSOR_FLOAT32 models, fp16 otherwise. 0 means "no clipping".
    const auto& clipOperand = *GetInput(operation, operands, kCellClipParam);
    const auto& projOperand = *GetInput(operation, operands, kProjClipParam);
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        params_.cell_clip = getScalarDataWithDefault<float>(clipOperand, 0.0f);
        params_.proj_clip = getScalarDataWithDefault<float>(projOperand, 0.0f);
    } else {
        params_.cell_clip =
                static_cast<float>(getScalarDataWithDefault<_Float16>(clipOperand, 0.0f));
        params_.proj_clip =
                static_cast<float>(getScalarDataWithDefault<_Float16>(projOperand, 0.0f));
    }
    const auto& mergeOutputsOperand = *GetInput(operation, operands, kMergeOutputsParam);
    params_.merge_outputs = getScalarDataWithDefault<bool>(mergeOutputsOperand, false);
    const auto& timeMajorOperand = *GetInput(operation, operands, kTimeMajorParam);
    params_.time_major = getScalarDataWithDefault<bool>(timeMajorOperand, false);
    // Layer normalization is considered enabled iff the forward input layer
    // norm weights were provided.
    params_.use_layer_norm = !IsNullInput(fw_input_layer_norm_weights_);

    fw_output_ = GetOutput(operation, operands, kFwOutputTensor);
    // With merged outputs there is no separate backward output tensor.
    if (!params_.merge_outputs) {
        bw_output_ = GetOutput(operation, operands, kBwOutputTensor);
    }

    // 5 or 6 outputs means the caller also wants the final activation/cell
    // states (5 when outputs are merged, 6 when fw/bw outputs are separate).
    params_.output_state = (operation.outputs.size() == 5 || operation.outputs.size() == 6);
    if (params_.output_state) {
        // When outputs are merged the bw output tensor is absent, so the
        // state-output indices shift down by one.
        uint32_t delta = params_.merge_outputs ? 1 : 0;
        fw_output_activation_state_ =
                GetOutput(operation, operands, kFwOutputActivationStateTensor - delta);
        fw_output_cell_state_ = GetOutput(operation, operands, kFwOutputCellStateTensor - delta);
        bw_output_activation_state_ =
                GetOutput(operation, operands, kBwOutputActivationStateTensor - delta);
        bw_output_cell_state_ = GetOutput(operation, operands, kBwOutputCellStateTensor - delta);
    }
}
208 
// Validates operand presence, dimensions and mutual consistency, and computes
// the output shapes (fw/bw outputs plus optional state outputs) and the
// internal scratch-buffer shapes. Returns false if validation fails.
// NOTE(review): this function mixes NN_CHECK* and NN_RET_CHECK* macros for
// very similar validations — confirm the intended failure behavior of each.
bool BidirectionalSequenceLSTM::Prepare(const Operation& operation, RunTimeOperandInfo* operands,
                                        Shape* fwOutputShape, Shape* bwOutputShape,
                                        Shape* fwOutputActivationState, Shape* fwOutputCellState,
                                        Shape* bwOutputActivationState, Shape* bwOutputCellState) {
    // Check we have all the inputs and outputs we need.
    constexpr int requiredInputs[] = {
            kInputTensor,
            kFwInputToForgetWeightsTensor,
            kFwInputToCellWeightsTensor,
            kFwInputToOutputWeightsTensor,
            kFwRecurrentToForgetWeightsTensor,
            kFwRecurrentToCellWeightsTensor,
            kFwRecurrentToOutputWeightsTensor,
            kFwForgetGateBiasTensor,
            kFwCellGateBiasTensor,
            kFwOutputGateBiasTensor,
            kBwInputToForgetWeightsTensor,
            kBwInputToCellWeightsTensor,
            kBwInputToOutputWeightsTensor,
            kBwRecurrentToForgetWeightsTensor,
            kBwRecurrentToCellWeightsTensor,
            kBwRecurrentToOutputWeightsTensor,
            kBwForgetGateBiasTensor,
            kBwCellGateBiasTensor,
            kBwOutputGateBiasTensor,
            kFwInputActivationStateTensor,
            kFwInputCellStateTensor,
            kBwInputActivationStateTensor,
            kBwInputCellStateTensor,
            kActivationParam,
            kCellClipParam,
            kProjClipParam,
            kMergeOutputsParam,
            kTimeMajorParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!IsNullInput(GetInput(operation, operands, requiredInput)))
                << "required input " << requiredInput << " is omitted";
    }

    // Check that the scalar operands' buffers are large enough.
    const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
    NN_RET_CHECK(activationOperand.length >= sizeof(int32_t));
    const auto& cellOperand = *GetInput(operation, operands, kCellClipParam);
    const auto& projOperand = *GetInput(operation, operands, kProjClipParam);
    // Clip scalars are float for float32 models and fp16 otherwise, matching
    // how the constructor decodes them.
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(cellOperand.length >= sizeof(float));
        NN_RET_CHECK(projOperand.length >= sizeof(float));
    } else {
        NN_RET_CHECK(cellOperand.length >= sizeof(_Float16));
        NN_RET_CHECK(projOperand.length >= sizeof(_Float16));
    }
    const auto& mergeOutputsOperand = *GetInput(operation, operands, kMergeOutputsParam);
    NN_RET_CHECK(mergeOutputsOperand.length >= sizeof(bool));
    const auto& timeMajorOperand = *GetInput(operation, operands, kTimeMajorParam);
    NN_RET_CHECK(timeMajorOperand.length >= sizeof(bool));

    // Inferring batch size, number of outputs and number of cells from the
    // input tensors. The input is rank-3; time and batch dimensions swap
    // depending on time_major.
    NN_CHECK(NumDimensions(input_) == 3);
    const uint32_t max_time = SizeOfDimension(input_, params_.time_major ? 0 : 1);
    const uint32_t n_batch = SizeOfDimension(input_, params_.time_major ? 1 : 0);
    const uint32_t n_fw_input = SizeOfDimension(input_, 2);

    // Forward cell count comes from the input->output weights' first
    // dimension; its second dimension must match the input size.
    const uint32_t n_fw_cell = SizeOfDimension(fw_input_to_output_weights_, 0);
    NN_CHECK_EQ(NumDimensions(fw_input_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(fw_input_to_output_weights_, 1), n_fw_input);

    // Forward output size comes from the recurrent weights' second dimension.
    NN_CHECK_EQ(NumDimensions(fw_recurrent_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(fw_recurrent_to_output_weights_, 0), n_fw_cell);
    const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);

    const uint32_t n_bw_cell = SizeOfDimension(bw_input_to_output_weights_, 0);

    NN_CHECK_EQ(NumDimensions(bw_recurrent_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(bw_recurrent_to_output_weights_, 0), n_bw_cell);
    const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);

    // Check that input tensor dimensions matches with each other.
    // This also populates params_ (e.g. use_cifg) from the forward tensors.
    if (!LSTMCell::CheckInputTensorDimensions(
                input_, fw_input_to_input_weights_, fw_input_to_forget_weights_,
                fw_input_to_cell_weights_, fw_input_to_output_weights_,
                fw_recurrent_to_input_weights_, fw_recurrent_to_forget_weights_,
                fw_recurrent_to_cell_weights_, fw_recurrent_to_output_weights_,
                fw_cell_to_input_weights_, fw_cell_to_forget_weights_, fw_cell_to_output_weights_,
                fw_input_gate_bias_, fw_forget_gate_bias_, fw_cell_bias_, fw_output_gate_bias_,
                fw_projection_weights_, fw_projection_bias_, fw_input_layer_norm_weights_,
                fw_forget_layer_norm_weights_, fw_cell_layer_norm_weights_,
                fw_output_layer_norm_weights_, n_fw_input, n_fw_output, n_fw_cell, &params_)) {
        return false;
    }

    // With CIFG there is no input gate, so aux input-gate weights must be
    // absent too.
    if (params_.use_cifg) {
        NN_RET_CHECK(IsNullInput(fw_aux_input_to_input_weights_) &&
                     IsNullInput(bw_aux_input_to_input_weights_));
    }

    // Auxiliary weights must be provided as a complete set (per direction) or
    // not at all; with CIFG the aux input-gate weights are exempt.
    const bool aux_fw_weights_all_or_none =
            ((params_.use_cifg || !IsNullInput(fw_aux_input_to_input_weights_)) &&
             !IsNullInput(fw_aux_input_to_forget_weights_) &&
             !IsNullInput(fw_aux_input_to_cell_weights_) &&
             !IsNullInput(fw_aux_input_to_output_weights_)) ||
            (IsNullInput(fw_aux_input_to_input_weights_) &&
             IsNullInput(fw_aux_input_to_forget_weights_) &&
             IsNullInput(fw_aux_input_to_cell_weights_) &&
             IsNullInput(fw_aux_input_to_output_weights_));
    const bool aux_bw_weights_all_or_none =
            ((params_.use_cifg || !IsNullInput(bw_aux_input_to_input_weights_)) &&
             !IsNullInput(bw_aux_input_to_forget_weights_) &&
             !IsNullInput(bw_aux_input_to_cell_weights_) &&
             !IsNullInput(bw_aux_input_to_output_weights_)) ||
            (IsNullInput(bw_aux_input_to_input_weights_) &&
             IsNullInput(bw_aux_input_to_forget_weights_) &&
             IsNullInput(bw_aux_input_to_cell_weights_) &&
             IsNullInput(bw_aux_input_to_output_weights_));

    NN_RET_CHECK(aux_fw_weights_all_or_none && aux_bw_weights_all_or_none);
    const bool has_aux_input = !IsNullInput(aux_input_);
    const bool has_fw_aux_weights = !IsNullInput(fw_aux_input_to_forget_weights_);
    const bool has_bw_aux_weights = !IsNullInput(bw_aux_input_to_forget_weights_);

    // Both directions must agree on whether aux weights are used.
    NN_RET_CHECK(has_fw_aux_weights == has_bw_aux_weights);

    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(has_aux_input, has_fw_aux_weights, &linkingMode));

    if (has_aux_input) {
        // Check that aux_input has the same dimensions (except last) as the input.
        NN_CHECK_EQ(aux_input_->shape().dimensions[0], input_->shape().dimensions[0]);
        NN_CHECK_EQ(aux_input_->shape().dimensions[1], input_->shape().dimensions[1]);
    }

    if (has_fw_aux_weights) {
        // NOTE(review): the aux feature size is taken from input_'s last
        // dimension, not aux_input_'s — confirm this is intentional (the last
        // dimension of aux_input_ is otherwise unconstrained here).
        int n_aux_input = SizeOfDimension(input_, 2);

        // Check forward auxiliary input shapes
        {
            NN_RET_CHECK_EQ(NumDimensions(fw_aux_input_to_input_weights_), 2);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_input_weights_, 0), n_fw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_input_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(fw_aux_input_to_forget_weights_), 2);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_forget_weights_, 0), n_fw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_forget_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(fw_aux_input_to_cell_weights_), 2);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_cell_weights_, 0), n_fw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_cell_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(fw_aux_input_to_output_weights_), 2);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_output_weights_, 0), n_fw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_output_weights_, 1), n_aux_input);
        }

        // Check backward auxiliary input shapes
        {
            NN_RET_CHECK_EQ(NumDimensions(bw_aux_input_to_input_weights_), 2);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_input_weights_, 0), n_bw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_input_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(bw_aux_input_to_forget_weights_), 2);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_forget_weights_, 0), n_bw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_forget_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(bw_aux_input_to_cell_weights_), 2);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_cell_weights_, 0), n_bw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_cell_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(bw_aux_input_to_output_weights_), 2);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_output_weights_, 0), n_bw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_output_weights_, 1), n_aux_input);
        }
    }

    // Forward output: [time, batch, ...] or [batch, time, ...] per time_major;
    // the last dimension doubles up when fw/bw outputs are merged.
    const Shape& inputShape = input_->shape();
    fwOutputShape->type = inputShape.type;
    fwOutputShape->offset = inputShape.offset;
    fwOutputShape->scale = inputShape.scale;
    fwOutputShape->dimensions.resize(3);
    fwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
    fwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
    fwOutputShape->dimensions[2] = params_.merge_outputs ? n_fw_output + n_bw_output : n_fw_output;

    // In PARALLEL_LINKING mode the backward network consumes the auxiliary
    // input instead of the regular one.
    const RunTimeOperandInfo* bw_input =
            linkingMode == LinkingMode::PARALLEL_LINKING ? aux_input_ : input_;
    const uint32_t n_bw_input = SizeOfDimension(bw_input, 2);
    // Check that input tensor dimensions matches with each other.
    if (!LSTMCell::CheckInputTensorDimensions(
                bw_input, bw_input_to_input_weights_, bw_input_to_forget_weights_,
                bw_input_to_cell_weights_, bw_input_to_output_weights_,
                bw_recurrent_to_input_weights_, bw_recurrent_to_forget_weights_,
                bw_recurrent_to_cell_weights_, bw_recurrent_to_output_weights_,
                bw_cell_to_input_weights_, bw_cell_to_forget_weights_, bw_cell_to_output_weights_,
                bw_input_gate_bias_, bw_forget_gate_bias_, bw_cell_bias_, bw_output_gate_bias_,
                bw_projection_weights_, bw_projection_bias_, bw_input_layer_norm_weights_,
                bw_forget_layer_norm_weights_, bw_cell_layer_norm_weights_,
                bw_output_layer_norm_weights_, n_bw_input, n_bw_output, n_bw_cell, &params_)) {
        return false;
    }

    // A separate backward output shape only exists when outputs are not merged.
    if (!params_.merge_outputs) {
        bwOutputShape->type = inputShape.type;
        bwOutputShape->offset = inputShape.offset;
        bwOutputShape->scale = inputShape.scale;
        bwOutputShape->dimensions.resize(3);
        bwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
        bwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
        bwOutputShape->dimensions[2] = n_bw_output;
    }

    // State outputs mirror the shapes of the corresponding input states.
    if (params_.output_state) {
        *fwOutputActivationState = fw_activation_state_->shape();
        *fwOutputCellState = fw_cell_state_->shape();
        *bwOutputActivationState = bw_activation_state_->shape();
        *bwOutputCellState = bw_cell_state_->shape();
    }

    // Scratch holds per-gate intermediate values: 3 gates with CIFG (no input
    // gate), 4 otherwise.
    if (params_.use_cifg) {
        fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 3};
        bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 3};
    } else {
        fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 4};
        bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 4};
    }
    fw_scratch_shape_.type = bw_scratch_shape_.type = inputShape.type;
    fw_scratch_shape_.offset = bw_scratch_shape_.offset = inputShape.offset;
    fw_scratch_shape_.scale = bw_scratch_shape_.scale = inputShape.scale;

    return true;
}
439 
Eval()440 bool BidirectionalSequenceLSTM::Eval() {
441     const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);
442     const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);
443     std::vector<uint32_t> fw_output_dims = input_->shape().dimensions;
444     fw_output_dims[2] = n_fw_output;
445     std::vector<uint32_t> bw_output_dims = fw_output_dims;
446     bw_output_dims[2] = n_bw_output;
447     const uint32_t n_fw_output_elements = fw_output_dims[0] * fw_output_dims[1] * fw_output_dims[2];
448     const uint32_t n_output_elements =
449             fw_output_dims[0] * fw_output_dims[1] * (fw_output_dims[2] + bw_output_dims[2]);
450 
451     const bool has_aux_input = !IsNullInput(aux_input_);
452     const bool has_aux_weights = !IsNullInput(fw_aux_input_to_forget_weights_);
453 
454     LinkingMode linkingMode;
455     NN_RET_CHECK(getLinkingMode(has_aux_input, has_aux_weights, &linkingMode));
456 
457     switch (input_->type) {
458         case OperandType::TENSOR_FLOAT32: {
459             const float* bwInput = GetBuffer<const float>(input_);
460             Shape bwInputShape = input_->shape();
461             const float* auxInput = GetOptionalBuffer<const float>(aux_input_);
462             if (linkingMode == LinkingMode::PARALLEL_LINKING) {
463                 bwInput = GetBuffer<const float>(aux_input_);
464                 bwInputShape = aux_input_->shape();
465                 auxInput = nullptr;
466             }
467 
468             float* fw_output_activation_state_buffer = nullptr;
469             float* fw_output_cell_state_buffer = nullptr;
470             std::vector<float> fw_output_activation_state;
471             std::vector<float> fw_output_cell_state;
472             if (params_.output_state) {
473                 fw_output_activation_state_buffer = GetBuffer<float>(fw_output_activation_state_);
474                 fw_output_cell_state_buffer = GetBuffer<float>(fw_output_cell_state_);
475             } else {
476                 fw_output_activation_state.resize(
477                         getNumberOfElements(fw_activation_state_->shape()));
478                 fw_output_cell_state.resize(getNumberOfElements(fw_cell_state_->shape()));
479 
480                 fw_output_activation_state_buffer = fw_output_activation_state.data();
481                 fw_output_cell_state_buffer = fw_output_cell_state.data();
482             }
483             std::vector<float> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
484             const bool kForwardSequence = true;
485             LSTMCell::LSTMEvalFloat32(
486                     params_, GetBuffer<const float>(input_), input_->shape(),
487                     GetBuffer<const float>(fw_input_to_input_weights_),
488                     GetBuffer<const float>(fw_input_to_forget_weights_),
489                     GetBuffer<const float>(fw_input_to_cell_weights_),
490                     GetBuffer<const float>(fw_input_to_output_weights_),
491                     fw_input_to_output_weights_->shape(),
492                     GetBuffer<const float>(fw_recurrent_to_input_weights_),
493                     GetBuffer<const float>(fw_recurrent_to_forget_weights_),
494                     GetBuffer<const float>(fw_recurrent_to_cell_weights_),
495                     GetBuffer<const float>(fw_recurrent_to_output_weights_),
496                     fw_recurrent_to_output_weights_->shape(),
497                     GetBuffer<const float>(fw_cell_to_input_weights_),
498                     GetBuffer<const float>(fw_cell_to_forget_weights_),
499                     GetBuffer<const float>(fw_cell_to_output_weights_), auxInput,
500                     GetOptionalBuffer<const float>(fw_aux_input_to_input_weights_),
501                     GetOptionalBuffer<const float>(fw_aux_input_to_forget_weights_),
502                     GetOptionalBuffer<const float>(fw_aux_input_to_cell_weights_),
503                     GetOptionalBuffer<const float>(fw_aux_input_to_output_weights_),
504                     GetBuffer<const float>(fw_input_gate_bias_),
505                     GetBuffer<const float>(fw_forget_gate_bias_),
506                     GetBuffer<const float>(fw_cell_bias_),
507                     GetBuffer<const float>(fw_output_gate_bias_),
508                     GetBuffer<const float>(fw_projection_weights_),
509                     GetBuffer<const float>(fw_projection_bias_),
510                     GetBuffer<const float>(fw_activation_state_),
511                     GetBuffer<const float>(fw_cell_state_),
512                     GetOptionalBuffer<const float>(fw_input_layer_norm_weights_),
513                     GetOptionalBuffer<const float>(fw_forget_layer_norm_weights_),
514                     GetOptionalBuffer<const float>(fw_cell_layer_norm_weights_),
515                     GetOptionalBuffer<const float>(fw_output_layer_norm_weights_),
516                     fw_output_activation_state_buffer, fw_output_cell_state_buffer,
517                     GetBuffer<float>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
518                     kForwardSequence);
519 
520             float* bw_output_activation_state_buffer;
521             float* bw_output_cell_state_buffer;
522             std::vector<float> bw_output_activation_state;
523             std::vector<float> bw_output_cell_state;
524             if (params_.output_state) {
525                 bw_output_activation_state_buffer = GetBuffer<float>(bw_output_activation_state_);
526                 bw_output_cell_state_buffer = GetBuffer<float>(bw_output_cell_state_);
527             } else {
528                 bw_output_activation_state.resize(
529                         getNumberOfElements(bw_activation_state_->shape()));
530                 bw_output_cell_state.resize(getNumberOfElements(bw_cell_state_->shape()));
531 
532                 bw_output_activation_state_buffer = bw_output_activation_state.data();
533                 bw_output_cell_state_buffer = bw_output_cell_state.data();
534             }
535             std::vector<float> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
536             const bool kBackwardSequence = false;
537             LSTMCell::LSTMEvalFloat32(
538                     params_, bwInput, bwInputShape,
539                     GetBuffer<const float>(bw_input_to_input_weights_),
540                     GetBuffer<const float>(bw_input_to_forget_weights_),
541                     GetBuffer<const float>(bw_input_to_cell_weights_),
542                     GetBuffer<const float>(bw_input_to_output_weights_),
543                     bw_input_to_output_weights_->shape(),
544                     GetBuffer<const float>(bw_recurrent_to_input_weights_),
545                     GetBuffer<const float>(bw_recurrent_to_forget_weights_),
546                     GetBuffer<const float>(bw_recurrent_to_cell_weights_),
547                     GetBuffer<const float>(bw_recurrent_to_output_weights_),
548                     bw_recurrent_to_output_weights_->shape(),
549                     GetBuffer<const float>(bw_cell_to_input_weights_),
550                     GetBuffer<const float>(bw_cell_to_forget_weights_),
551                     GetBuffer<const float>(bw_cell_to_output_weights_), auxInput,
552                     GetOptionalBuffer<const float>(bw_aux_input_to_input_weights_),
553                     GetOptionalBuffer<const float>(bw_aux_input_to_forget_weights_),
554                     GetOptionalBuffer<const float>(bw_aux_input_to_cell_weights_),
555                     GetOptionalBuffer<const float>(bw_aux_input_to_output_weights_),
556                     GetBuffer<const float>(bw_input_gate_bias_),
557                     GetBuffer<const float>(bw_forget_gate_bias_),
558                     GetBuffer<const float>(bw_cell_bias_),
559                     GetBuffer<const float>(bw_output_gate_bias_),
560                     GetBuffer<const float>(bw_projection_weights_),
561                     GetBuffer<const float>(bw_projection_bias_),
562                     GetBuffer<const float>(bw_activation_state_),
563                     GetBuffer<const float>(bw_cell_state_),
564                     GetOptionalBuffer<const float>(bw_input_layer_norm_weights_),
565                     GetOptionalBuffer<const float>(bw_forget_layer_norm_weights_),
566                     GetOptionalBuffer<const float>(bw_cell_layer_norm_weights_),
567                     GetOptionalBuffer<const float>(bw_output_layer_norm_weights_),
568                     bw_output_activation_state_buffer, bw_output_cell_state_buffer,
569                     params_.merge_outputs ? GetBuffer<float>(fw_output_) + n_fw_output_elements
570                                           : GetBuffer<float>(bw_output_),
571                     bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
572             if (params_.merge_outputs) {
573                 std::vector<float> temp(n_output_elements);
574                 mergeThirdDimension(GetBuffer<float>(fw_output_), fw_output_dims,
575                                     GetBuffer<float>(fw_output_) + n_fw_output_elements,
576                                     bw_output_dims, temp.data());
577                 std::copy(temp.data(), temp.data() + n_output_elements,
578                           GetBuffer<float>(fw_output_));
579             }
580         } break;
581         case OperandType::TENSOR_FLOAT16: {
582             const _Float16* bwInput = GetBuffer<const _Float16>(input_);
583             Shape bwInputShape = input_->shape();
584             const _Float16* auxInput = GetOptionalBuffer<const _Float16>(aux_input_);
585             if (linkingMode == LinkingMode::PARALLEL_LINKING) {
586                 bwInput = GetBuffer<const _Float16>(aux_input_);
587                 bwInputShape = aux_input_->shape();
588                 auxInput = nullptr;
589             }
590 
591             _Float16* fw_output_activation_state_buffer;
592             _Float16* fw_output_cell_state_buffer;
593             std::vector<_Float16> fw_output_activation_state;
594             std::vector<_Float16> fw_output_cell_state;
595             if (params_.output_state) {
596                 fw_output_activation_state_buffer =
597                         GetBuffer<_Float16>(fw_output_activation_state_);
598                 fw_output_cell_state_buffer = GetBuffer<_Float16>(fw_output_cell_state_);
599             } else {
600                 fw_output_activation_state.resize(
601                         getNumberOfElements(fw_activation_state_->shape()));
602                 fw_output_cell_state.resize(getNumberOfElements(fw_cell_state_->shape()));
603 
604                 fw_output_activation_state_buffer = fw_output_activation_state.data();
605                 fw_output_cell_state_buffer = fw_output_cell_state.data();
606             }
607             std::vector<_Float16> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
608             const bool kForwardSequence = true;
609             LSTMCell::LSTMEvalFloat16(
610                     params_, GetBuffer<const _Float16>(input_), input_->shape(),
611                     GetOptionalBuffer<const _Float16>(fw_input_to_input_weights_),
612                     GetBuffer<const _Float16>(fw_input_to_forget_weights_),
613                     GetBuffer<const _Float16>(fw_input_to_cell_weights_),
614                     GetBuffer<const _Float16>(fw_input_to_output_weights_),
615                     fw_input_to_output_weights_->shape(),
616                     GetOptionalBuffer<const _Float16>(fw_recurrent_to_input_weights_),
617                     GetBuffer<const _Float16>(fw_recurrent_to_forget_weights_),
618                     GetBuffer<const _Float16>(fw_recurrent_to_cell_weights_),
619                     GetBuffer<const _Float16>(fw_recurrent_to_output_weights_),
620                     fw_recurrent_to_output_weights_->shape(),
621                     GetOptionalBuffer<const _Float16>(fw_cell_to_input_weights_),
622                     GetOptionalBuffer<const _Float16>(fw_cell_to_forget_weights_),
623                     GetOptionalBuffer<const _Float16>(fw_cell_to_output_weights_), auxInput,
624                     GetOptionalBuffer<const _Float16>(fw_aux_input_to_input_weights_),
625                     GetOptionalBuffer<const _Float16>(fw_aux_input_to_forget_weights_),
626                     GetOptionalBuffer<const _Float16>(fw_aux_input_to_cell_weights_),
627                     GetOptionalBuffer<const _Float16>(fw_aux_input_to_output_weights_),
628                     GetOptionalBuffer<const _Float16>(fw_input_gate_bias_),
629                     GetBuffer<const _Float16>(fw_forget_gate_bias_),
630                     GetBuffer<const _Float16>(fw_cell_bias_),
631                     GetBuffer<const _Float16>(fw_output_gate_bias_),
632                     GetOptionalBuffer<const _Float16>(fw_projection_weights_),
633                     GetOptionalBuffer<const _Float16>(fw_projection_bias_),
634                     GetBuffer<const _Float16>(fw_activation_state_),
635                     GetBuffer<const _Float16>(fw_cell_state_),
636                     GetOptionalBuffer<const _Float16>(fw_input_layer_norm_weights_),
637                     GetOptionalBuffer<const _Float16>(fw_forget_layer_norm_weights_),
638                     GetOptionalBuffer<const _Float16>(fw_cell_layer_norm_weights_),
639                     GetOptionalBuffer<const _Float16>(fw_output_layer_norm_weights_),
640                     fw_output_activation_state_buffer, fw_output_cell_state_buffer,
641                     GetBuffer<_Float16>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
642                     kForwardSequence);
643 
644             _Float16* bw_output_activation_state_buffer;
645             _Float16* bw_output_cell_state_buffer;
646             std::vector<_Float16> bw_output_activation_state;
647             std::vector<_Float16> bw_output_cell_state;
648             if (params_.output_state) {
649                 bw_output_activation_state_buffer =
650                         GetBuffer<_Float16>(bw_output_activation_state_);
651                 bw_output_cell_state_buffer = GetBuffer<_Float16>(bw_output_cell_state_);
652             } else {
653                 bw_output_activation_state.resize(
654                         getNumberOfElements(bw_activation_state_->shape()));
655                 bw_output_cell_state.resize(getNumberOfElements(bw_cell_state_->shape()));
656 
657                 bw_output_activation_state_buffer = bw_output_activation_state.data();
658                 bw_output_cell_state_buffer = bw_output_cell_state.data();
659             }
660             std::vector<_Float16> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
661             const bool kBackwardSequence = false;
662             LSTMCell::LSTMEvalFloat16(
663                     params_, bwInput, bwInputShape,
664                     GetOptionalBuffer<const _Float16>(bw_input_to_input_weights_),
665                     GetBuffer<const _Float16>(bw_input_to_forget_weights_),
666                     GetBuffer<const _Float16>(bw_input_to_cell_weights_),
667                     GetBuffer<const _Float16>(bw_input_to_output_weights_),
668                     bw_input_to_output_weights_->shape(),
669                     GetOptionalBuffer<const _Float16>(bw_recurrent_to_input_weights_),
670                     GetBuffer<const _Float16>(bw_recurrent_to_forget_weights_),
671                     GetBuffer<const _Float16>(bw_recurrent_to_cell_weights_),
672                     GetBuffer<const _Float16>(bw_recurrent_to_output_weights_),
673                     bw_recurrent_to_output_weights_->shape(),
674                     GetOptionalBuffer<const _Float16>(bw_cell_to_input_weights_),
675                     GetOptionalBuffer<const _Float16>(bw_cell_to_forget_weights_),
676                     GetOptionalBuffer<const _Float16>(bw_cell_to_output_weights_), auxInput,
677                     GetOptionalBuffer<const _Float16>(bw_aux_input_to_input_weights_),
678                     GetOptionalBuffer<const _Float16>(bw_aux_input_to_forget_weights_),
679                     GetOptionalBuffer<const _Float16>(bw_aux_input_to_cell_weights_),
680                     GetOptionalBuffer<const _Float16>(bw_aux_input_to_output_weights_),
681                     GetOptionalBuffer<const _Float16>(bw_input_gate_bias_),
682                     GetBuffer<const _Float16>(bw_forget_gate_bias_),
683                     GetBuffer<const _Float16>(bw_cell_bias_),
684                     GetBuffer<const _Float16>(bw_output_gate_bias_),
685                     GetOptionalBuffer<const _Float16>(bw_projection_weights_),
686                     GetOptionalBuffer<const _Float16>(bw_projection_bias_),
687                     GetBuffer<const _Float16>(bw_activation_state_),
688                     GetBuffer<const _Float16>(bw_cell_state_),
689                     GetOptionalBuffer<const _Float16>(bw_input_layer_norm_weights_),
690                     GetOptionalBuffer<const _Float16>(bw_forget_layer_norm_weights_),
691                     GetOptionalBuffer<const _Float16>(bw_cell_layer_norm_weights_),
692                     GetOptionalBuffer<const _Float16>(bw_output_layer_norm_weights_),
693                     bw_output_activation_state_buffer, bw_output_cell_state_buffer,
694                     params_.merge_outputs ? GetBuffer<_Float16>(fw_output_) + n_fw_output_elements
695                                           : GetBuffer<_Float16>(bw_output_),
696                     bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
697             if (params_.merge_outputs) {
698                 std::vector<_Float16> temp(n_output_elements);
699                 mergeThirdDimension(GetBuffer<_Float16>(fw_output_), fw_output_dims,
700                                     GetBuffer<_Float16>(fw_output_) + n_fw_output_elements,
701                                     bw_output_dims, temp.data());
702                 std::copy(temp.data(), temp.data() + n_output_elements,
703                           GetBuffer<_Float16>(fw_output_));
704             }
705         } break;
706         default: {
707             LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
708             return false;
709         }
710     }
711     return true;
712 }
713 
714 }  // namespace nn
715 }  // namespace android
716