1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include "BidirectionalSequenceLSTM.h"
20
21 #include <algorithm>
22 #include <vector>
23
24 #include "CpuExecutor.h"
25 #include "CpuOperationUtils.h"
26 #include "HalInterfaces.h"
27 #include "OperationsUtils.h"
28 #include "Tracing.h"
29
30 namespace android {
31 namespace nn {
32
33 namespace {
34
35 using namespace hal;
36
37 template <typename T>
GetBuffer(RunTimeOperandInfo * operand)38 inline T* GetBuffer(RunTimeOperandInfo* operand) {
39 return reinterpret_cast<T*>(operand->buffer);
40 }
41
42 template <typename T>
GetBuffer(const RunTimeOperandInfo * operand)43 inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
44 return reinterpret_cast<const T*>(operand->buffer);
45 }
46
47 template <typename T>
GetOptionalBuffer(const RunTimeOperandInfo * operand)48 inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
49 return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
50 }
51
// How the auxiliary input (if present) is wired into the backward network.
// See getLinkingMode() for how the mode is derived from the operand set.
enum class LinkingMode {
    NO_LINKING,        // no auxiliary input and no auxiliary weights
    PARALLEL_LINKING,  // auxiliary input is fed directly to the backward network
    CROSS_LINKING,     // auxiliary input is multiplied by the auxiliary weights
};
57
getLinkingMode(bool hasAuxInput,bool hasAuxWeights,LinkingMode * linkingMode)58 bool getLinkingMode(bool hasAuxInput, bool hasAuxWeights, LinkingMode* linkingMode) {
59 // Three possible configurations for three possible linking modes:
60 // 1) NO_LINKING -- no auxiliary tensors at all
61 // 2) PARALLEL_LINKING -- auxiliary input is provided and used as a regular
62 // input to the backward network, so the auxiliary weights are omitted.
63 // 3) CROSS_LINKING -- auxiliary input is provided and multiplied by
64 // auxiliary weights.
65 if (!hasAuxInput && !hasAuxWeights) {
66 *linkingMode = LinkingMode::NO_LINKING;
67 } else if (hasAuxInput && !hasAuxWeights) {
68 *linkingMode = LinkingMode::PARALLEL_LINKING;
69 } else if (hasAuxInput && hasAuxWeights) {
70 *linkingMode = LinkingMode::CROSS_LINKING;
71 } else {
72 NN_RET_CHECK_FAIL()
73 << "Unsupported auxiliary tensors configuration for BIDIRECTIONAL_SEQUENCE_RNN.";
74 }
75
76 return true;
77 }
78
79 } // anonymous namespace
80
// Caches pointers to all operands of the BIDIRECTIONAL_SEQUENCE_LSTM operation
// and decodes the scalar parameters into params_. Operands marked "optional"
// may be omitted (null) inputs; they are validated later in Prepare() and read
// through GetOptionalBuffer()/IsNullInput() in Eval().
BidirectionalSequenceLSTM::BidirectionalSequenceLSTM(const Operation& operation,
                                                     RunTimeOperandInfo* operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    // Forward-direction input weights.
    fw_input_to_input_weights_ =
            GetInput(operation, operands, kFwInputToInputWeightsTensor);  // optional
    fw_input_to_forget_weights_ = GetInput(operation, operands, kFwInputToForgetWeightsTensor);
    fw_input_to_cell_weights_ = GetInput(operation, operands, kFwInputToCellWeightsTensor);
    fw_input_to_output_weights_ = GetInput(operation, operands, kFwInputToOutputWeightsTensor);

    // Forward-direction recurrent weights.
    fw_recurrent_to_input_weights_ =
            GetInput(operation, operands, kFwRecurrentToInputWeightsTensor);  // optional
    fw_recurrent_to_forget_weights_ =
            GetInput(operation, operands, kFwRecurrentToForgetWeightsTensor);
    fw_recurrent_to_cell_weights_ = GetInput(operation, operands, kFwRecurrentToCellWeightsTensor);
    fw_recurrent_to_output_weights_ =
            GetInput(operation, operands, kFwRecurrentToOutputWeightsTensor);

    // Forward-direction peephole weights (all optional).
    fw_cell_to_input_weights_ =
            GetInput(operation, operands, kFwCellToInputWeightsTensor);  // optional
    fw_cell_to_forget_weights_ =
            GetInput(operation, operands, kFwCellToForgetWeightsTensor);  // optional
    fw_cell_to_output_weights_ =
            GetInput(operation, operands, kFwCellToOutputWeightsTensor);  // optional

    // Forward-direction gate biases.
    fw_input_gate_bias_ = GetInput(operation, operands, kFwInputGateBiasTensor);
    fw_forget_gate_bias_ = GetInput(operation, operands, kFwForgetGateBiasTensor);
    fw_cell_bias_ = GetInput(operation, operands, kFwCellGateBiasTensor);
    fw_output_gate_bias_ = GetInput(operation, operands, kFwOutputGateBiasTensor);

    fw_projection_weights_ = GetInput(operation, operands, kFwProjectionWeightsTensor);  // optional
    fw_projection_bias_ = GetInput(operation, operands, kFwProjectionBiasTensor);        // optional

    // Forward-direction initial state.
    fw_activation_state_ = GetInput(operation, operands, kFwInputActivationStateTensor);
    fw_cell_state_ = GetInput(operation, operands, kFwInputCellStateTensor);

    // Backward-direction input weights.
    bw_input_to_input_weights_ =
            GetInput(operation, operands, kBwInputToInputWeightsTensor);  // optional
    bw_input_to_forget_weights_ = GetInput(operation, operands, kBwInputToForgetWeightsTensor);
    bw_input_to_cell_weights_ = GetInput(operation, operands, kBwInputToCellWeightsTensor);
    bw_input_to_output_weights_ = GetInput(operation, operands, kBwInputToOutputWeightsTensor);

    // Backward-direction recurrent weights.
    bw_recurrent_to_input_weights_ =
            GetInput(operation, operands, kBwRecurrentToInputWeightsTensor);  // optional
    bw_recurrent_to_forget_weights_ =
            GetInput(operation, operands, kBwRecurrentToForgetWeightsTensor);
    bw_recurrent_to_cell_weights_ = GetInput(operation, operands, kBwRecurrentToCellWeightsTensor);
    bw_recurrent_to_output_weights_ =
            GetInput(operation, operands, kBwRecurrentToOutputWeightsTensor);

    // Backward-direction peephole weights (all optional).
    bw_cell_to_input_weights_ =
            GetInput(operation, operands, kBwCellToInputWeightsTensor);  // optional
    bw_cell_to_forget_weights_ =
            GetInput(operation, operands, kBwCellToForgetWeightsTensor);  // optional
    bw_cell_to_output_weights_ =
            GetInput(operation, operands, kBwCellToOutputWeightsTensor);  // optional

    // Backward-direction gate biases.
    bw_input_gate_bias_ = GetInput(operation, operands, kBwInputGateBiasTensor);
    bw_forget_gate_bias_ = GetInput(operation, operands, kBwForgetGateBiasTensor);
    bw_cell_bias_ = GetInput(operation, operands, kBwCellGateBiasTensor);
    bw_output_gate_bias_ = GetInput(operation, operands, kBwOutputGateBiasTensor);

    bw_projection_weights_ = GetInput(operation, operands, kBwProjectionWeightsTensor);  // optional
    bw_projection_bias_ = GetInput(operation, operands, kBwProjectionBiasTensor);        // optional

    // Backward-direction initial state.
    bw_activation_state_ = GetInput(operation, operands, kBwInputActivationStateTensor);
    bw_cell_state_ = GetInput(operation, operands, kBwInputCellStateTensor);

    // Auxiliary input and its per-direction weights (used for cross/parallel
    // linking; see LinkingMode).
    aux_input_ = GetInput(operation, operands, kAuxInputTensor);
    fw_aux_input_to_input_weights_ = GetInput(operation, operands, kFwAuxInputToInputWeightsTensor);
    fw_aux_input_to_forget_weights_ =
            GetInput(operation, operands, kFwAuxInputToForgetWeightsTensor);
    fw_aux_input_to_cell_weights_ = GetInput(operation, operands, kFwAuxInputToCellWeightsTensor);
    fw_aux_input_to_output_weights_ =
            GetInput(operation, operands, kFwAuxInputToOutputWeightsTensor);
    bw_aux_input_to_input_weights_ = GetInput(operation, operands, kBwAuxInputToInputWeightsTensor);
    bw_aux_input_to_forget_weights_ =
            GetInput(operation, operands, kBwAuxInputToForgetWeightsTensor);
    bw_aux_input_to_cell_weights_ = GetInput(operation, operands, kBwAuxInputToCellWeightsTensor);
    bw_aux_input_to_output_weights_ =
            GetInput(operation, operands, kBwAuxInputToOutputWeightsTensor);

    // Layer-normalization weights (all optional).
    fw_input_layer_norm_weights_ = GetInput(operation, operands, kFwInputLayerNormWeightsTensor);
    fw_forget_layer_norm_weights_ = GetInput(operation, operands, kFwForgetLayerNormWeightsTensor);
    fw_cell_layer_norm_weights_ = GetInput(operation, operands, kFwCellLayerNormWeightsTensor);
    fw_output_layer_norm_weights_ = GetInput(operation, operands, kFwOutputLayerNormWeightsTensor);
    bw_input_layer_norm_weights_ = GetInput(operation, operands, kBwInputLayerNormWeightsTensor);
    bw_forget_layer_norm_weights_ = GetInput(operation, operands, kBwForgetLayerNormWeightsTensor);
    bw_cell_layer_norm_weights_ = GetInput(operation, operands, kBwCellLayerNormWeightsTensor);
    bw_output_layer_norm_weights_ = GetInput(operation, operands, kBwOutputLayerNormWeightsTensor);

    // Scalar parameters. Defaults are used only if the buffers are absent;
    // Prepare() separately checks that the buffers are large enough.
    const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
    params_.activation = static_cast<TfLiteFusedActivation>(getScalarDataWithDefault<int32_t>(
            activationOperand, TfLiteFusedActivation::kTfLiteActNone));
    const auto& clipOperand = *GetInput(operation, operands, kCellClipParam);
    const auto& projOperand = *GetInput(operation, operands, kProjClipParam);
    // Clip scalars are fp32 for float32 models and fp16 otherwise, widened to
    // float for internal use.
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        params_.cell_clip = getScalarDataWithDefault<float>(clipOperand, 0.0f);
        params_.proj_clip = getScalarDataWithDefault<float>(projOperand, 0.0f);
    } else {
        params_.cell_clip =
                static_cast<float>(getScalarDataWithDefault<_Float16>(clipOperand, 0.0f));
        params_.proj_clip =
                static_cast<float>(getScalarDataWithDefault<_Float16>(projOperand, 0.0f));
    }
    const auto& mergeOutputsOperand = *GetInput(operation, operands, kMergeOutputsParam);
    params_.merge_outputs = getScalarDataWithDefault<bool>(mergeOutputsOperand, false);
    const auto& timeMajorOperand = *GetInput(operation, operands, kTimeMajorParam);
    params_.time_major = getScalarDataWithDefault<bool>(timeMajorOperand, false);
    // Layer norm is considered enabled iff the fw input layer-norm weights are
    // present.
    params_.use_layer_norm = !IsNullInput(fw_input_layer_norm_weights_);

    fw_output_ = GetOutput(operation, operands, kFwOutputTensor);
    // With merged outputs there is no separate backward output tensor.
    if (!params_.merge_outputs) {
        bw_output_ = GetOutput(operation, operands, kBwOutputTensor);
    }

    // 5 or 6 outputs means the caller also wants the final activation/cell
    // states (5 when outputs are merged, 6 otherwise). `delta` shifts the
    // state-output indices down by one when the bw output tensor is absent.
    params_.output_state = (operation.outputs.size() == 5 || operation.outputs.size() == 6);
    if (params_.output_state) {
        uint32_t delta = params_.merge_outputs ? 1 : 0;
        fw_output_activation_state_ =
                GetOutput(operation, operands, kFwOutputActivationStateTensor - delta);
        fw_output_cell_state_ = GetOutput(operation, operands, kFwOutputCellStateTensor - delta);
        bw_output_activation_state_ =
                GetOutput(operation, operands, kBwOutputActivationStateTensor - delta);
        bw_output_cell_state_ = GetOutput(operation, operands, kBwOutputCellStateTensor - delta);
    }
}
208
// Validates all operand shapes/presence and computes the output shapes.
// Returns false if any required input is omitted, any scalar buffer is too
// small, or any tensor dimension is inconsistent. On success fills in the
// fw/bw output shapes (and, when requested, the state output shapes) and the
// internal scratch shapes used by Eval().
bool BidirectionalSequenceLSTM::Prepare(const Operation& operation, RunTimeOperandInfo* operands,
                                        Shape* fwOutputShape, Shape* bwOutputShape,
                                        Shape* fwOutputActivationState, Shape* fwOutputCellState,
                                        Shape* bwOutputActivationState, Shape* bwOutputCellState) {
    // Check we have all the inputs and outputs we need.
    constexpr int requiredInputs[] = {
            kInputTensor,
            kFwInputToForgetWeightsTensor,
            kFwInputToCellWeightsTensor,
            kFwInputToOutputWeightsTensor,
            kFwRecurrentToForgetWeightsTensor,
            kFwRecurrentToCellWeightsTensor,
            kFwRecurrentToOutputWeightsTensor,
            kFwForgetGateBiasTensor,
            kFwCellGateBiasTensor,
            kFwOutputGateBiasTensor,
            kBwInputToForgetWeightsTensor,
            kBwInputToCellWeightsTensor,
            kBwInputToOutputWeightsTensor,
            kBwRecurrentToForgetWeightsTensor,
            kBwRecurrentToCellWeightsTensor,
            kBwRecurrentToOutputWeightsTensor,
            kBwForgetGateBiasTensor,
            kBwCellGateBiasTensor,
            kBwOutputGateBiasTensor,
            kFwInputActivationStateTensor,
            kFwInputCellStateTensor,
            kBwInputActivationStateTensor,
            kBwInputCellStateTensor,
            kActivationParam,
            kCellClipParam,
            kProjClipParam,
            kMergeOutputsParam,
            kTimeMajorParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!IsNullInput(GetInput(operation, operands, requiredInput)))
                << "required input " << requiredInput << " is omitted";
    }

    // Check that the scalar operands' buffers are large enough.
    const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
    NN_RET_CHECK(activationOperand.length >= sizeof(int32_t));
    const auto& cellOperand = *GetInput(operation, operands, kCellClipParam);
    const auto& projOperand = *GetInput(operation, operands, kProjClipParam);
    // Clip scalars are fp32 for float32 models, fp16 otherwise (see ctor).
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(cellOperand.length >= sizeof(float));
        NN_RET_CHECK(projOperand.length >= sizeof(float));
    } else {
        NN_RET_CHECK(cellOperand.length >= sizeof(_Float16));
        NN_RET_CHECK(projOperand.length >= sizeof(_Float16));
    }
    const auto& mergeOutputsOperand = *GetInput(operation, operands, kMergeOutputsParam);
    NN_RET_CHECK(mergeOutputsOperand.length >= sizeof(bool));
    const auto& timeMajorOperand = *GetInput(operation, operands, kTimeMajorParam);
    NN_RET_CHECK(timeMajorOperand.length >= sizeof(bool));

    // Inferring batch size, number of outputs and number of cells from the
    // input tensors. Input layout is [max_time, n_batch, input_size] when
    // time-major, [n_batch, max_time, input_size] otherwise.
    NN_CHECK(NumDimensions(input_) == 3);
    const uint32_t max_time = SizeOfDimension(input_, params_.time_major ? 0 : 1);
    const uint32_t n_batch = SizeOfDimension(input_, params_.time_major ? 1 : 0);
    const uint32_t n_fw_input = SizeOfDimension(input_, 2);

    // Forward cell/output counts from the weight shapes.
    const uint32_t n_fw_cell = SizeOfDimension(fw_input_to_output_weights_, 0);
    NN_CHECK_EQ(NumDimensions(fw_input_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(fw_input_to_output_weights_, 1), n_fw_input);

    NN_CHECK_EQ(NumDimensions(fw_recurrent_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(fw_recurrent_to_output_weights_, 0), n_fw_cell);
    const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);

    // Backward cell/output counts from the weight shapes.
    const uint32_t n_bw_cell = SizeOfDimension(bw_input_to_output_weights_, 0);

    NN_CHECK_EQ(NumDimensions(bw_recurrent_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(bw_recurrent_to_output_weights_, 0), n_bw_cell);
    const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);

    // Check that input tensor dimensions matches with each other.
    // This also sets params_.use_cifg from the forward weights.
    if (!LSTMCell::CheckInputTensorDimensions(
                input_, fw_input_to_input_weights_, fw_input_to_forget_weights_,
                fw_input_to_cell_weights_, fw_input_to_output_weights_,
                fw_recurrent_to_input_weights_, fw_recurrent_to_forget_weights_,
                fw_recurrent_to_cell_weights_, fw_recurrent_to_output_weights_,
                fw_cell_to_input_weights_, fw_cell_to_forget_weights_, fw_cell_to_output_weights_,
                fw_input_gate_bias_, fw_forget_gate_bias_, fw_cell_bias_, fw_output_gate_bias_,
                fw_projection_weights_, fw_projection_bias_, fw_input_layer_norm_weights_,
                fw_forget_layer_norm_weights_, fw_cell_layer_norm_weights_,
                fw_output_layer_norm_weights_, n_fw_input, n_fw_output, n_fw_cell, &params_)) {
        return false;
    }

    // CIFG has no input gate, so aux-input-to-input weights must be absent.
    if (params_.use_cifg) {
        NN_RET_CHECK(IsNullInput(fw_aux_input_to_input_weights_) &&
                     IsNullInput(bw_aux_input_to_input_weights_));
    }

    // Auxiliary weights must be provided all-or-none per direction (the input
    // gate's aux weights may be absent under CIFG).
    const bool aux_fw_weights_all_or_none =
            ((params_.use_cifg || !IsNullInput(fw_aux_input_to_input_weights_)) &&
             !IsNullInput(fw_aux_input_to_forget_weights_) &&
             !IsNullInput(fw_aux_input_to_cell_weights_) &&
             !IsNullInput(fw_aux_input_to_output_weights_)) ||
            (IsNullInput(fw_aux_input_to_input_weights_) &&
             IsNullInput(fw_aux_input_to_forget_weights_) &&
             IsNullInput(fw_aux_input_to_cell_weights_) &&
             IsNullInput(fw_aux_input_to_output_weights_));
    const bool aux_bw_weights_all_or_none =
            ((params_.use_cifg || !IsNullInput(bw_aux_input_to_input_weights_)) &&
             !IsNullInput(bw_aux_input_to_forget_weights_) &&
             !IsNullInput(bw_aux_input_to_cell_weights_) &&
             !IsNullInput(bw_aux_input_to_output_weights_)) ||
            (IsNullInput(bw_aux_input_to_input_weights_) &&
             IsNullInput(bw_aux_input_to_forget_weights_) &&
             IsNullInput(bw_aux_input_to_cell_weights_) &&
             IsNullInput(bw_aux_input_to_output_weights_));

    NN_RET_CHECK(aux_fw_weights_all_or_none && aux_bw_weights_all_or_none);
    const bool has_aux_input = !IsNullInput(aux_input_);
    const bool has_fw_aux_weights = !IsNullInput(fw_aux_input_to_forget_weights_);
    const bool has_bw_aux_weights = !IsNullInput(bw_aux_input_to_forget_weights_);

    // Both directions must agree on whether aux weights are used.
    NN_RET_CHECK(has_fw_aux_weights == has_bw_aux_weights);

    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(has_aux_input, has_fw_aux_weights, &linkingMode));

    if (has_aux_input) {
        // Check that aux_input has the same dimensions (except last) as the input.
        NN_CHECK_EQ(aux_input_->shape().dimensions[0], input_->shape().dimensions[0]);
        NN_CHECK_EQ(aux_input_->shape().dimensions[1], input_->shape().dimensions[1]);
    }

    if (has_fw_aux_weights) {
        // NOTE(review): n_aux_input is taken from input_'s last dimension, not
        // aux_input_'s. Under CROSS_LINKING those could in principle differ —
        // confirm whether SizeOfDimension(aux_input_, 2) was intended here.
        int n_aux_input = SizeOfDimension(input_, 2);

        // Check forward auxiliary input shapes
        {
            NN_RET_CHECK_EQ(NumDimensions(fw_aux_input_to_input_weights_), 2);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_input_weights_, 0), n_fw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_input_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(fw_aux_input_to_forget_weights_), 2);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_forget_weights_, 0), n_fw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_forget_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(fw_aux_input_to_cell_weights_), 2);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_cell_weights_, 0), n_fw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_cell_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(fw_aux_input_to_output_weights_), 2);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_output_weights_, 0), n_fw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(fw_aux_input_to_output_weights_, 1), n_aux_input);
        }

        // Check backward auxiliary input shapes
        {
            NN_RET_CHECK_EQ(NumDimensions(bw_aux_input_to_input_weights_), 2);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_input_weights_, 0), n_bw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_input_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(bw_aux_input_to_forget_weights_), 2);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_forget_weights_, 0), n_bw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_forget_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(bw_aux_input_to_cell_weights_), 2);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_cell_weights_, 0), n_bw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_cell_weights_, 1), n_aux_input);

            NN_RET_CHECK_EQ(NumDimensions(bw_aux_input_to_output_weights_), 2);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_output_weights_, 0), n_bw_cell);
            NN_RET_CHECK_EQ(SizeOfDimension(bw_aux_input_to_output_weights_, 1), n_aux_input);
        }
    }

    // Forward output: same quantization as the input; last dim doubles up when
    // outputs are merged (fw and bw concatenated on the feature axis).
    const Shape& inputShape = input_->shape();
    fwOutputShape->type = inputShape.type;
    fwOutputShape->offset = inputShape.offset;
    fwOutputShape->scale = inputShape.scale;
    fwOutputShape->dimensions.resize(3);
    fwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
    fwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
    fwOutputShape->dimensions[2] = params_.merge_outputs ? n_fw_output + n_bw_output : n_fw_output;

    // Under PARALLEL_LINKING the backward network reads the aux input instead
    // of the primary input.
    const RunTimeOperandInfo* bw_input =
            linkingMode == LinkingMode::PARALLEL_LINKING ? aux_input_ : input_;
    const uint32_t n_bw_input = SizeOfDimension(bw_input, 2);
    // Check that input tensor dimensions matches with each other.
    if (!LSTMCell::CheckInputTensorDimensions(
                bw_input, bw_input_to_input_weights_, bw_input_to_forget_weights_,
                bw_input_to_cell_weights_, bw_input_to_output_weights_,
                bw_recurrent_to_input_weights_, bw_recurrent_to_forget_weights_,
                bw_recurrent_to_cell_weights_, bw_recurrent_to_output_weights_,
                bw_cell_to_input_weights_, bw_cell_to_forget_weights_, bw_cell_to_output_weights_,
                bw_input_gate_bias_, bw_forget_gate_bias_, bw_cell_bias_, bw_output_gate_bias_,
                bw_projection_weights_, bw_projection_bias_, bw_input_layer_norm_weights_,
                bw_forget_layer_norm_weights_, bw_cell_layer_norm_weights_,
                bw_output_layer_norm_weights_, n_bw_input, n_bw_output, n_bw_cell, &params_)) {
        return false;
    }

    // Separate backward output shape only when outputs are not merged.
    if (!params_.merge_outputs) {
        bwOutputShape->type = inputShape.type;
        bwOutputShape->offset = inputShape.offset;
        bwOutputShape->scale = inputShape.scale;
        bwOutputShape->dimensions.resize(3);
        bwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
        bwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
        bwOutputShape->dimensions[2] = n_bw_output;
    }

    // State outputs mirror the input state shapes.
    if (params_.output_state) {
        *fwOutputActivationState = fw_activation_state_->shape();
        *fwOutputCellState = fw_cell_state_->shape();
        *bwOutputActivationState = bw_activation_state_->shape();
        *bwOutputCellState = bw_cell_state_->shape();
    }

    // Scratch holds the gate activations: 3 gates under CIFG (no input gate),
    // 4 otherwise.
    if (params_.use_cifg) {
        fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 3};
        bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 3};
    } else {
        fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 4};
        bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 4};
    }
    fw_scratch_shape_.type = bw_scratch_shape_.type = inputShape.type;
    fw_scratch_shape_.offset = bw_scratch_shape_.offset = inputShape.offset;
    fw_scratch_shape_.scale = bw_scratch_shape_.scale = inputShape.scale;

    return true;
}
439
// Runs the forward pass over the sequence, then the backward pass, for either
// float32 or float16 tensors. When merge_outputs is set, the backward results
// are written right after the forward results inside fw_output_ and then the
// two are interleaved along the feature axis via mergeThirdDimension(). When
// output_state is not requested, the final states go to throwaway local
// buffers (LSTMEval* always needs somewhere to write them).
bool BidirectionalSequenceLSTM::Eval() {
    const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);
    const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);
    std::vector<uint32_t> fw_output_dims = input_->shape().dimensions;
    fw_output_dims[2] = n_fw_output;
    std::vector<uint32_t> bw_output_dims = fw_output_dims;
    bw_output_dims[2] = n_bw_output;
    const uint32_t n_fw_output_elements = fw_output_dims[0] * fw_output_dims[1] * fw_output_dims[2];
    const uint32_t n_output_elements =
            fw_output_dims[0] * fw_output_dims[1] * (fw_output_dims[2] + bw_output_dims[2]);

    const bool has_aux_input = !IsNullInput(aux_input_);
    const bool has_aux_weights = !IsNullInput(fw_aux_input_to_forget_weights_);

    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(has_aux_input, has_aux_weights, &linkingMode));

    switch (input_->type) {
        case OperandType::TENSOR_FLOAT32: {
            // Under PARALLEL_LINKING the backward network consumes the aux
            // input as its regular input and no aux input is passed through.
            const float* bwInput = GetBuffer<const float>(input_);
            Shape bwInputShape = input_->shape();
            const float* auxInput = GetOptionalBuffer<const float>(aux_input_);
            if (linkingMode == LinkingMode::PARALLEL_LINKING) {
                bwInput = GetBuffer<const float>(aux_input_);
                bwInputShape = aux_input_->shape();
                auxInput = nullptr;
            }

            // Forward state outputs: caller's tensors if requested, otherwise
            // scratch vectors sized like the input states.
            float* fw_output_activation_state_buffer = nullptr;
            float* fw_output_cell_state_buffer = nullptr;
            std::vector<float> fw_output_activation_state;
            std::vector<float> fw_output_cell_state;
            if (params_.output_state) {
                fw_output_activation_state_buffer = GetBuffer<float>(fw_output_activation_state_);
                fw_output_cell_state_buffer = GetBuffer<float>(fw_output_cell_state_);
            } else {
                fw_output_activation_state.resize(
                        getNumberOfElements(fw_activation_state_->shape()));
                fw_output_cell_state.resize(getNumberOfElements(fw_cell_state_->shape()));

                fw_output_activation_state_buffer = fw_output_activation_state.data();
                fw_output_cell_state_buffer = fw_output_cell_state.data();
            }
            std::vector<float> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
            const bool kForwardSequence = true;
            // NOTE(review): this float32 path fetches the optional weight
            // tensors with GetBuffer rather than GetOptionalBuffer (unlike the
            // float16 path below) — presumably LSTMEvalFloat32 gates their use
            // on params_ flags and never dereferences null ones; confirm.
            LSTMCell::LSTMEvalFloat32(
                    params_, GetBuffer<const float>(input_), input_->shape(),
                    GetBuffer<const float>(fw_input_to_input_weights_),
                    GetBuffer<const float>(fw_input_to_forget_weights_),
                    GetBuffer<const float>(fw_input_to_cell_weights_),
                    GetBuffer<const float>(fw_input_to_output_weights_),
                    fw_input_to_output_weights_->shape(),
                    GetBuffer<const float>(fw_recurrent_to_input_weights_),
                    GetBuffer<const float>(fw_recurrent_to_forget_weights_),
                    GetBuffer<const float>(fw_recurrent_to_cell_weights_),
                    GetBuffer<const float>(fw_recurrent_to_output_weights_),
                    fw_recurrent_to_output_weights_->shape(),
                    GetBuffer<const float>(fw_cell_to_input_weights_),
                    GetBuffer<const float>(fw_cell_to_forget_weights_),
                    GetBuffer<const float>(fw_cell_to_output_weights_), auxInput,
                    GetOptionalBuffer<const float>(fw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_output_weights_),
                    GetBuffer<const float>(fw_input_gate_bias_),
                    GetBuffer<const float>(fw_forget_gate_bias_),
                    GetBuffer<const float>(fw_cell_bias_),
                    GetBuffer<const float>(fw_output_gate_bias_),
                    GetBuffer<const float>(fw_projection_weights_),
                    GetBuffer<const float>(fw_projection_bias_),
                    GetBuffer<const float>(fw_activation_state_),
                    GetBuffer<const float>(fw_cell_state_),
                    GetOptionalBuffer<const float>(fw_input_layer_norm_weights_),
                    GetOptionalBuffer<const float>(fw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const float>(fw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const float>(fw_output_layer_norm_weights_),
                    fw_output_activation_state_buffer, fw_output_cell_state_buffer,
                    GetBuffer<float>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
                    kForwardSequence);

            // Backward state outputs, same pattern as forward.
            float* bw_output_activation_state_buffer;
            float* bw_output_cell_state_buffer;
            std::vector<float> bw_output_activation_state;
            std::vector<float> bw_output_cell_state;
            if (params_.output_state) {
                bw_output_activation_state_buffer = GetBuffer<float>(bw_output_activation_state_);
                bw_output_cell_state_buffer = GetBuffer<float>(bw_output_cell_state_);
            } else {
                bw_output_activation_state.resize(
                        getNumberOfElements(bw_activation_state_->shape()));
                bw_output_cell_state.resize(getNumberOfElements(bw_cell_state_->shape()));

                bw_output_activation_state_buffer = bw_output_activation_state.data();
                bw_output_cell_state_buffer = bw_output_cell_state.data();
            }
            std::vector<float> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
            const bool kBackwardSequence = false;
            LSTMCell::LSTMEvalFloat32(
                    params_, bwInput, bwInputShape,
                    GetBuffer<const float>(bw_input_to_input_weights_),
                    GetBuffer<const float>(bw_input_to_forget_weights_),
                    GetBuffer<const float>(bw_input_to_cell_weights_),
                    GetBuffer<const float>(bw_input_to_output_weights_),
                    bw_input_to_output_weights_->shape(),
                    GetBuffer<const float>(bw_recurrent_to_input_weights_),
                    GetBuffer<const float>(bw_recurrent_to_forget_weights_),
                    GetBuffer<const float>(bw_recurrent_to_cell_weights_),
                    GetBuffer<const float>(bw_recurrent_to_output_weights_),
                    bw_recurrent_to_output_weights_->shape(),
                    GetBuffer<const float>(bw_cell_to_input_weights_),
                    GetBuffer<const float>(bw_cell_to_forget_weights_),
                    GetBuffer<const float>(bw_cell_to_output_weights_), auxInput,
                    GetOptionalBuffer<const float>(bw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_output_weights_),
                    GetBuffer<const float>(bw_input_gate_bias_),
                    GetBuffer<const float>(bw_forget_gate_bias_),
                    GetBuffer<const float>(bw_cell_bias_),
                    GetBuffer<const float>(bw_output_gate_bias_),
                    GetBuffer<const float>(bw_projection_weights_),
                    GetBuffer<const float>(bw_projection_bias_),
                    GetBuffer<const float>(bw_activation_state_),
                    GetBuffer<const float>(bw_cell_state_),
                    GetOptionalBuffer<const float>(bw_input_layer_norm_weights_),
                    GetOptionalBuffer<const float>(bw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const float>(bw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const float>(bw_output_layer_norm_weights_),
                    bw_output_activation_state_buffer, bw_output_cell_state_buffer,
                    // Merged: bw results land right after fw results so the
                    // two halves can be interleaved in place below.
                    params_.merge_outputs ? GetBuffer<float>(fw_output_) + n_fw_output_elements
                                          : GetBuffer<float>(bw_output_),
                    bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
            if (params_.merge_outputs) {
                // Interleave fw/bw along the feature axis via a temp buffer.
                std::vector<float> temp(n_output_elements);
                mergeThirdDimension(GetBuffer<float>(fw_output_), fw_output_dims,
                                    GetBuffer<float>(fw_output_) + n_fw_output_elements,
                                    bw_output_dims, temp.data());
                std::copy(temp.data(), temp.data() + n_output_elements,
                          GetBuffer<float>(fw_output_));
            }
        } break;
        case OperandType::TENSOR_FLOAT16: {
            // Mirrors the float32 case above, with _Float16 buffers and
            // GetOptionalBuffer for the optional weight tensors.
            const _Float16* bwInput = GetBuffer<const _Float16>(input_);
            Shape bwInputShape = input_->shape();
            const _Float16* auxInput = GetOptionalBuffer<const _Float16>(aux_input_);
            if (linkingMode == LinkingMode::PARALLEL_LINKING) {
                bwInput = GetBuffer<const _Float16>(aux_input_);
                bwInputShape = aux_input_->shape();
                auxInput = nullptr;
            }

            _Float16* fw_output_activation_state_buffer;
            _Float16* fw_output_cell_state_buffer;
            std::vector<_Float16> fw_output_activation_state;
            std::vector<_Float16> fw_output_cell_state;
            if (params_.output_state) {
                fw_output_activation_state_buffer =
                        GetBuffer<_Float16>(fw_output_activation_state_);
                fw_output_cell_state_buffer = GetBuffer<_Float16>(fw_output_cell_state_);
            } else {
                fw_output_activation_state.resize(
                        getNumberOfElements(fw_activation_state_->shape()));
                fw_output_cell_state.resize(getNumberOfElements(fw_cell_state_->shape()));

                fw_output_activation_state_buffer = fw_output_activation_state.data();
                fw_output_cell_state_buffer = fw_output_cell_state.data();
            }
            std::vector<_Float16> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
            const bool kForwardSequence = true;
            LSTMCell::LSTMEvalFloat16(
                    params_, GetBuffer<const _Float16>(input_), input_->shape(),
                    GetOptionalBuffer<const _Float16>(fw_input_to_input_weights_),
                    GetBuffer<const _Float16>(fw_input_to_forget_weights_),
                    GetBuffer<const _Float16>(fw_input_to_cell_weights_),
                    GetBuffer<const _Float16>(fw_input_to_output_weights_),
                    fw_input_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(fw_recurrent_to_input_weights_),
                    GetBuffer<const _Float16>(fw_recurrent_to_forget_weights_),
                    GetBuffer<const _Float16>(fw_recurrent_to_cell_weights_),
                    GetBuffer<const _Float16>(fw_recurrent_to_output_weights_),
                    fw_recurrent_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(fw_cell_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(fw_cell_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(fw_cell_to_output_weights_), auxInput,
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(fw_input_gate_bias_),
                    GetBuffer<const _Float16>(fw_forget_gate_bias_),
                    GetBuffer<const _Float16>(fw_cell_bias_),
                    GetBuffer<const _Float16>(fw_output_gate_bias_),
                    GetOptionalBuffer<const _Float16>(fw_projection_weights_),
                    GetOptionalBuffer<const _Float16>(fw_projection_bias_),
                    GetBuffer<const _Float16>(fw_activation_state_),
                    GetBuffer<const _Float16>(fw_cell_state_),
                    GetOptionalBuffer<const _Float16>(fw_input_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(fw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(fw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(fw_output_layer_norm_weights_),
                    fw_output_activation_state_buffer, fw_output_cell_state_buffer,
                    GetBuffer<_Float16>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
                    kForwardSequence);

            _Float16* bw_output_activation_state_buffer;
            _Float16* bw_output_cell_state_buffer;
            std::vector<_Float16> bw_output_activation_state;
            std::vector<_Float16> bw_output_cell_state;
            if (params_.output_state) {
                bw_output_activation_state_buffer =
                        GetBuffer<_Float16>(bw_output_activation_state_);
                bw_output_cell_state_buffer = GetBuffer<_Float16>(bw_output_cell_state_);
            } else {
                bw_output_activation_state.resize(
                        getNumberOfElements(bw_activation_state_->shape()));
                bw_output_cell_state.resize(getNumberOfElements(bw_cell_state_->shape()));

                bw_output_activation_state_buffer = bw_output_activation_state.data();
                bw_output_cell_state_buffer = bw_output_cell_state.data();
            }
            std::vector<_Float16> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
            const bool kBackwardSequence = false;
            LSTMCell::LSTMEvalFloat16(
                    params_, bwInput, bwInputShape,
                    GetOptionalBuffer<const _Float16>(bw_input_to_input_weights_),
                    GetBuffer<const _Float16>(bw_input_to_forget_weights_),
                    GetBuffer<const _Float16>(bw_input_to_cell_weights_),
                    GetBuffer<const _Float16>(bw_input_to_output_weights_),
                    bw_input_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(bw_recurrent_to_input_weights_),
                    GetBuffer<const _Float16>(bw_recurrent_to_forget_weights_),
                    GetBuffer<const _Float16>(bw_recurrent_to_cell_weights_),
                    GetBuffer<const _Float16>(bw_recurrent_to_output_weights_),
                    bw_recurrent_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(bw_cell_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(bw_cell_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(bw_cell_to_output_weights_), auxInput,
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(bw_input_gate_bias_),
                    GetBuffer<const _Float16>(bw_forget_gate_bias_),
                    GetBuffer<const _Float16>(bw_cell_bias_),
                    GetBuffer<const _Float16>(bw_output_gate_bias_),
                    GetOptionalBuffer<const _Float16>(bw_projection_weights_),
                    GetOptionalBuffer<const _Float16>(bw_projection_bias_),
                    GetBuffer<const _Float16>(bw_activation_state_),
                    GetBuffer<const _Float16>(bw_cell_state_),
                    GetOptionalBuffer<const _Float16>(bw_input_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(bw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(bw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(bw_output_layer_norm_weights_),
                    bw_output_activation_state_buffer, bw_output_cell_state_buffer,
                    params_.merge_outputs ? GetBuffer<_Float16>(fw_output_) + n_fw_output_elements
                                          : GetBuffer<_Float16>(bw_output_),
                    bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
            if (params_.merge_outputs) {
                // Interleave fw/bw along the feature axis via a temp buffer.
                std::vector<_Float16> temp(n_output_elements);
                mergeThirdDimension(GetBuffer<_Float16>(fw_output_), fw_output_dims,
                                    GetBuffer<_Float16>(fw_output_) + n_fw_output_elements,
                                    bw_output_dims, temp.data());
                std::copy(temp.data(), temp.data() + n_output_elements,
                          GetBuffer<_Float16>(fw_output_));
            }
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
            return false;
        }
    }
    return true;
}
713
714 } // namespace nn
715 } // namespace android
716