/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include <vector>

#include "HalInterfaces.h"
#include "IndexedShapeWrapper.h"
#include "LSTM.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"

namespace android {
namespace nn {
namespace unidirectional_sequence_lstm {

// Inputs
constexpr uint32_t kNumInputs = 28;

// Input tensor of size {max_time, n_batch, n_input}
constexpr uint32_t kInputTensor = 0;

// Input weight tensors of size: {n_cell, n_input}
constexpr uint32_t kInputToInputWeightsTensor = 1;  // Optional
constexpr uint32_t kInputToForgetWeightsTensor = 2;
constexpr uint32_t kInputToCellWeightsTensor = 3;
constexpr uint32_t kInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr uint32_t kRecurrentToInputWeightsTensor = 5;  // Optional
constexpr uint32_t kRecurrentToForgetWeightsTensor = 6;
constexpr uint32_t kRecurrentToCellWeightsTensor = 7;
constexpr uint32_t kRecurrentToOutputWeightsTensor = 8;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr uint32_t kCellToInputWeightsTensor = 9;    // Optional
constexpr uint32_t kCellToForgetWeightsTensor = 10;  // Optional
constexpr uint32_t kCellToOutputWeightsTensor = 11;  // Optional

// Gates bias tensors of size {n_cell}
constexpr uint32_t kInputGateBiasTensor = 12;  // Optional
constexpr uint32_t kForgetGateBiasTensor = 13;
constexpr uint32_t kCellGateBiasTensor = 14;
constexpr uint32_t kOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr uint32_t kProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr uint32_t kProjectionBiasTensor = 17;  // Optional

// Input from the output of the previous step, tensor of size {batch_size, n_output}
constexpr uint32_t kOutputStateInTensor = 18;
// Input from the cell state of the previous step, tensor of size {batch_size, n_cell}
constexpr uint32_t kCellStateInTensor = 19;

constexpr uint32_t kActivationParam = 20;
constexpr uint32_t kCellClipParam = 21;
constexpr uint32_t kProjClipParam = 22;
constexpr uint32_t kTimeMajorParam = 23;

// Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr uint32_t kInputLayerNormWeightsTensor = 24;   // Optional
constexpr uint32_t kForgetLayerNormWeightsTensor = 25;  // Optional
constexpr uint32_t kCellLayerNormWeightsTensor = 26;    // Optional
constexpr uint32_t kOutputLayerNormWeightsTensor = 27;  // Optional

// Output tensors.
constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kNumOutputsWithState = 3;

constexpr uint32_t kOutputTensor = 0;
constexpr uint32_t kOutputStateOutTensor = 1;
constexpr uint32_t kCellStateOutTensor = 2;

namespace {

using namespace hal;

inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) {
    return context->getInputBuffer(tensor) != nullptr;
}

inline bool isTimeMajor(IOperationExecutionContext* context) {
    return context->getInputValue<bool>(kTimeMajorParam);
}

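// Gathers the scalar LSTM parameters (activation, cell clip, projection clip) and infers
// the CIFG, peephole, layer normalization, and projection flags from which optional weight
// tensors are present. T is the type of the clip scalars (float or _Float16).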
template <typename T>
inline LSTMParams getLSTMParams(IOperationExecutionContext* context) {
    LSTMParams params;
    params.activation =
            static_cast<TfLiteFusedActivation>(context->getInputValue<int32_t>(kActivationParam));
    params.cell_clip = static_cast<float>(context->getInputValue<T>(kCellClipParam));
    params.proj_clip = static_cast<float>(context->getInputValue<T>(kProjClipParam));
    params.use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
    params.use_peephole = hasTensor(context, kCellToOutputWeightsTensor);
    params.use_layer_norm = hasTensor(context, kOutputLayerNormWeightsTensor);
    params.use_projection_weight = hasTensor(context, kProjectionWeightsTensor);
    params.use_projection_bias = hasTensor(context, kProjectionBiasTensor);
    return params;
}

}  // namespace

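// Checks the operand counts and types. Both the float32 and float16 variants take 28 inputs;
// requesting the optional state output tensors requires HAL version 1.3.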
bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    const uint32_t numOutputs = context->getNumOutputs();
    NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState);
    const OperandType inputType = context->getInputType(kInputTensor);
    std::vector<OperandType> inExpectedTypes;
    std::vector<OperandType> outExpectedTypes;
    if (inputType == OperandType::TENSOR_FLOAT32) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::INT32,          OperandType::FLOAT32,
                           OperandType::FLOAT32,        OperandType::BOOL,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32};
        outExpectedTypes = {OperandType::TENSOR_FLOAT32};
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::INT32,          OperandType::FLOAT16,
                           OperandType::FLOAT16,        OperandType::BOOL,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16};
        outExpectedTypes = {OperandType::TENSOR_FLOAT16};
    } else {
        NN_RET_CHECK_FAIL()
                << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_LSTM op: "
                << toString(inputType);
    }
    HalVersion minHalVersionSupported = HalVersion::V1_2;
    if (context->getNumOutputs() == kNumOutputsWithState) {
        minHalVersionSupported = HalVersion::V1_3;
        outExpectedTypes.insert(outExpectedTypes.end(), {inputType, inputType});
    }
    NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
    return validateHalVersion(context, minHalVersionSupported);
}

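// Checks that all required inputs are present, validates the weight, bias, and state tensor
// shapes against n_cell, n_input, n_output, and the batch size, enforces the all-or-none
// rules for the CIFG, peephole, and layer normalization weights, and sets the output shape(s).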
bool prepare(IOperationExecutionContext* context) {
    // Check that none of the required inputs are omitted
    const std::vector<int> requiredInputs = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kOutputStateInTensor,
            kCellStateInTensor,
            kActivationParam,
            kCellClipParam,
            kProjClipParam,
            kTimeMajorParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    const Shape inputShape = context->getInputShape(kInputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    NN_RET_CHECK_EQ(inputRank, 3) << "Invalid input tensor rank: " << inputRank;

    const uint32_t maxTime = getSizeOfDimension(inputShape, isTimeMajor(context) ? 0 : 1);
    const uint32_t batchSize = getSizeOfDimension(inputShape, isTimeMajor(context) ? 1 : 0);
    const uint32_t inputSize = getSizeOfDimension(inputShape, inputRank - 1);

    const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToOutputShape, 1), inputSize);
    const uint32_t numCells = getSizeOfDimension(inputToOutputShape, 0);

    const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToOutputShape, 0), numCells);
    const uint32_t outputSize = getSizeOfDimension(recurrentToOutputShape, 1);

    if (hasTensor(context, kInputToInputWeightsTensor)) {
        const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 1), inputSize);
    }

    const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 1), inputSize);
    const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 1), inputSize);

    if (hasTensor(context, kRecurrentToInputWeightsTensor)) {
        const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 1), outputSize);
    }

    const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 1), outputSize);
    const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 1), outputSize);

    // Make sure the input-gate parameters are either both present (regular
    // LSTM) or both absent (CIFG-LSTM).
    const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) &&
                                       hasTensor(context, kRecurrentToInputWeightsTensor)) ||
                                      (!hasTensor(context, kInputToInputWeightsTensor) &&
                                       !hasTensor(context, kRecurrentToInputWeightsTensor));
    NN_RET_CHECK(cifgWeightsAllOrNone);

    if (hasTensor(context, kCellToInputWeightsTensor)) {
        const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToInputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToInputShape, 0), numCells);
    }

    if (hasTensor(context, kCellToForgetWeightsTensor)) {
        const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToForgetShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToForgetShape, 0), numCells);
    }

    if (hasTensor(context, kCellToOutputWeightsTensor)) {
        const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToOutputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToOutputShape, 0), numCells);
    }

    // Make sure the peephole weights are either all present or all absent.
    const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor);
    const bool peepholeWeightsAllOrNone =
            ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) &&
             hasTensor(context, kCellToForgetWeightsTensor) &&
             hasTensor(context, kCellToOutputWeightsTensor)) ||
            (!hasTensor(context, kCellToInputWeightsTensor) &&
             !hasTensor(context, kCellToForgetWeightsTensor) &&
             !hasTensor(context, kCellToOutputWeightsTensor));
    NN_RET_CHECK(peepholeWeightsAllOrNone);

    if (!cifgUsed) {
        NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor));
        const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputGateBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputGateBiasShape, 0), numCells);
    } else {
        NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor))
                << "Input gate bias tensor is present when CIFG is used";
    }

    const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(forgetGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(forgetGateBiasShape, 0), numCells);
    const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellGateBiasShape, 0), numCells);
    const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputGateBiasShape, 0), numCells);

    if (hasTensor(context, kProjectionWeightsTensor)) {
        const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 0), outputSize);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 1), numCells);
    }

    if (hasTensor(context, kProjectionBiasTensor)) {
        const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionBiasShape, 0), outputSize);
    }

    const Shape outputStateShape = context->getInputShape(kOutputStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 1), outputSize);
    const Shape cellStateShape = context->getInputShape(kCellStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 1), numCells);

    if (hasTensor(context, kInputLayerNormWeightsTensor)) {
        const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kForgetLayerNormWeightsTensor)) {
        const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(forgetLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(forgetLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kCellLayerNormWeightsTensor)) {
        const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kOutputLayerNormWeightsTensor)) {
        const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(outputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(outputLayerNormShape, 0), numCells);
    }

    if (cifgUsed) {
        NN_RET_CHECK(!hasTensor(context, kInputLayerNormWeightsTensor))
                << "Input layer norm weights tensor is present when CIFG is used";
        const bool layerNormWeightsAllOrNoneCifg =
                (hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNoneCifg);
    } else {
        const bool layerNormWeightsAllOrNone =
                (hasTensor(context, kInputLayerNormWeightsTensor) &&
                 hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kInputLayerNormWeightsTensor) &&
                 !hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNone);
    }

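    // The output keeps the same {time, batch} layout as the input; only the innermost
    // dimension changes from n_input to n_output.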
    Shape outputShape = context->getInputShape(kInputTensor);
    outputShape.dimensions[2] = outputSize;

    if (context->getNumOutputs() == kNumOutputsWithState) {
        NN_RET_CHECK(!context->isOmittedOutput(kOutputStateOutTensor));
        NN_RET_CHECK(!context->isOmittedOutput(kCellStateOutTensor));

        Shape outputStateOutTensor = context->getInputShape(kOutputStateInTensor);
        outputStateOutTensor.dimensions.resize(2);
        outputStateOutTensor.dimensions[0] = batchSize;
        outputStateOutTensor.dimensions[1] = outputSize;
        NN_RET_CHECK(context->setOutputShape(kOutputStateOutTensor, outputStateOutTensor));

        Shape cellStateOutTensor = context->getInputShape(kCellStateInTensor);
        cellStateOutTensor.dimensions.resize(2);
        cellStateOutTensor.dimensions[0] = batchSize;
        cellStateOutTensor.dimensions[1] = numCells;
        NN_RET_CHECK(context->setOutputShape(kCellStateOutTensor, cellStateOutTensor));
    }

    return context->setOutputShape(kOutputTensor, outputShape);
}

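// Runs the LSTM over the whole sequence by dispatching to LSTMCell::LSTMEvalFloat32 or
// LSTMEvalFloat16 based on the input operand type. When the caller does not request the
// state outputs, temporary buffers are allocated to hold the final output and cell state.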
bool execute(IOperationExecutionContext* context) {
    const auto outputStateSize = getNumberOfElements(context->getInputShape(kOutputStateInTensor));
    const auto cellStateSize = getNumberOfElements(context->getInputShape(kCellStateInTensor));
    const bool use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
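    // With CIFG the input gate is derived from the forget gate, so scratch space is needed
    // for only three gates instead of four.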
    const auto scratchSize = use_cifg ? 3 * cellStateSize : 4 * cellStateSize;
    const bool useStateOutTensors = (context->getNumOutputs() == kNumOutputsWithState);

    const OperandType inputType = context->getInputType(kInputTensor);
    switch (inputType) {
        case OperandType::TENSOR_FLOAT32: {
            // Initialize empty vectors and resize below only if needed
            std::vector<float> outputStateOutBuffer;
            std::vector<float> cellStateOutBuffer;
            float* outputStateOut;
            float* cellStateOut;
            if (useStateOutTensors) {
                outputStateOut = context->getOutputBuffer<float>(kOutputStateOutTensor);
                cellStateOut = context->getOutputBuffer<float>(kCellStateOutTensor);
            } else {
                outputStateOutBuffer.resize(outputStateSize);
                cellStateOutBuffer.resize(cellStateSize);
                outputStateOut = outputStateOutBuffer.data();
                cellStateOut = cellStateOutBuffer.data();
            }
            std::vector<float> scratchBuffer(scratchSize);
            LSTMCell::LSTMEvalFloat32(
                    getLSTMParams<float>(context), context->getInputBuffer<float>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<float>(kInputToInputWeightsTensor),
                    context->getInputBuffer<float>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<float>(kInputToCellWeightsTensor),
                    context->getInputBuffer<float>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<float>(kCellToInputWeightsTensor),
                    context->getInputBuffer<float>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<float>(kCellToOutputWeightsTensor),
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<float>(kInputGateBiasTensor),
                    context->getInputBuffer<float>(kForgetGateBiasTensor),
                    context->getInputBuffer<float>(kCellGateBiasTensor),
                    context->getInputBuffer<float>(kOutputGateBiasTensor),
                    context->getInputBuffer<float>(kProjectionWeightsTensor),
                    context->getInputBuffer<float>(kProjectionBiasTensor),
                    context->getInputBuffer<float>(kOutputStateInTensor),
                    context->getInputBuffer<float>(kCellStateInTensor),
                    context->getInputBuffer<float>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kOutputLayerNormWeightsTensor), outputStateOut,
                    cellStateOut, context->getOutputBuffer<float>(kOutputTensor),
                    scratchBuffer.data(), isTimeMajor(context));
        } break;
        case OperandType::TENSOR_FLOAT16: {
            // Initialize empty vectors and resize below only if needed
            std::vector<_Float16> outputStateOutBuffer;
            std::vector<_Float16> cellStateOutBuffer;
            _Float16* outputStateOut;
            _Float16* cellStateOut;
            if (useStateOutTensors) {
                outputStateOut = context->getOutputBuffer<_Float16>(kOutputStateOutTensor);
                cellStateOut = context->getOutputBuffer<_Float16>(kCellStateOutTensor);
            } else {
                outputStateOutBuffer.resize(outputStateSize);
                cellStateOutBuffer.resize(cellStateSize);
                outputStateOut = outputStateOutBuffer.data();
                cellStateOut = cellStateOutBuffer.data();
            }
            std::vector<_Float16> scratchBuffer(scratchSize);
            LSTMCell::LSTMEvalFloat16(
                    getLSTMParams<_Float16>(context),
                    context->getInputBuffer<_Float16>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<_Float16>(kInputToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToOutputWeightsTensor),
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<_Float16>(kInputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kForgetGateBiasTensor),
                    context->getInputBuffer<_Float16>(kCellGateBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kProjectionWeightsTensor),
                    context->getInputBuffer<_Float16>(kProjectionBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputStateInTensor),
                    context->getInputBuffer<_Float16>(kCellStateInTensor),
                    context->getInputBuffer<_Float16>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kOutputLayerNormWeightsTensor),
                    outputStateOut, cellStateOut, context->getOutputBuffer<_Float16>(kOutputTensor),
                    scratchBuffer.data(), isTimeMajor(context));
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(inputType);
            return false;
        }
    }
    return true;
}

}  // namespace unidirectional_sequence_lstm

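// Many of this operation's weight tensors are optional, so it is registered with
// allowOmittedOperand set.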
NN_REGISTER_OPERATION(UNIDIRECTIONAL_SEQUENCE_LSTM, "UNIDIRECTIONAL_SEQUENCE_LSTM",
                      unidirectional_sequence_lstm::validate, unidirectional_sequence_lstm::prepare,
                      unidirectional_sequence_lstm::execute, .allowOmittedOperand = true);

}  // namespace nn
}  // namespace android