/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include <vector>

#include "HalInterfaces.h"
#include "IndexedShapeWrapper.h"
#include "LSTM.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"

namespace android {
namespace nn {
namespace unidirectional_sequence_lstm {

// Inputs
constexpr uint32_t kNumInputs = 28;

// Input tensor of size {max_time, n_batch, n_input}
constexpr uint32_t kInputTensor = 0;

// Input weight tensors of size: {n_cell, n_input}
constexpr uint32_t kInputToInputWeightsTensor = 1;  // Optional
constexpr uint32_t kInputToForgetWeightsTensor = 2;
constexpr uint32_t kInputToCellWeightsTensor = 3;
constexpr uint32_t kInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr uint32_t kRecurrentToInputWeightsTensor = 5;  // Optional
constexpr uint32_t kRecurrentToForgetWeightsTensor = 6;
constexpr uint32_t kRecurrentToCellWeightsTensor = 7;
constexpr uint32_t kRecurrentToOutputWeightsTensor = 8;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr uint32_t kCellToInputWeightsTensor = 9;    // Optional
constexpr uint32_t kCellToForgetWeightsTensor = 10;  // Optional
constexpr uint32_t kCellToOutputWeightsTensor = 11;  // Optional

// Gate bias tensors of size {n_cell}
constexpr uint32_t kInputGateBiasTensor = 12;  // Optional
constexpr uint32_t kForgetGateBiasTensor = 13;
constexpr uint32_t kCellGateBiasTensor = 14;
constexpr uint32_t kOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr uint32_t kProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr uint32_t kProjectionBiasTensor = 17;  // Optional

// Input from the output of the previous step, tensor of size {batch_size, n_output}
constexpr uint32_t kOutputStateInTensor = 18;
// Input from the cell state of the previous step, tensor of size {batch_size, n_cell}
constexpr uint32_t kCellStateInTensor = 19;

constexpr uint32_t kActivationParam = 20;
constexpr uint32_t kCellClipParam = 21;
constexpr uint32_t kProjClipParam = 22;
constexpr uint32_t kTimeMajorParam = 23;

// Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr uint32_t kInputLayerNormWeightsTensor = 24;   // Optional
constexpr uint32_t kForgetLayerNormWeightsTensor = 25;  // Optional
constexpr uint32_t kCellLayerNormWeightsTensor = 26;    // Optional
constexpr uint32_t kOutputLayerNormWeightsTensor = 27;  // Optional

// Output tensors.
constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kNumOutputsWithState = 3;

constexpr uint32_t kOutputTensor = 0;
constexpr uint32_t kOutputStateOutTensor = 1;
constexpr uint32_t kCellStateOutTensor = 2;

namespace {

using namespace hal;

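// Optional inputs that are omitted by the caller arrive with a null buffer,
// so the presence of an optional tensor is detected by checking its input
// buffer.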
inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) {
    return context->getInputBuffer(tensor) != nullptr;
}

inline bool isTimeMajor(IOperationExecutionContext* context) {
    return context->getInputValue<bool>(kTimeMajorParam);
}

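// Collects the scalar parameters and infers the LSTM variant from which
// optional tensors are present: absent input-to-input weights mean CIFG,
// cell-to-output weights imply peephole connections, output layer norm
// weights imply layer normalization, and the projection flags follow the
// projection weights and bias tensors.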
template <typename T>
inline LSTMParams getLSTMParams(IOperationExecutionContext* context) {
    LSTMParams params;
    params.activation =
            static_cast<TfLiteFusedActivation>(context->getInputValue<int32_t>(kActivationParam));
    params.cell_clip = static_cast<float>(context->getInputValue<T>(kCellClipParam));
    params.proj_clip = static_cast<float>(context->getInputValue<T>(kProjClipParam));
    params.use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
    params.use_peephole = hasTensor(context, kCellToOutputWeightsTensor);
    params.use_layer_norm = hasTensor(context, kOutputLayerNormWeightsTensor);
    params.use_projection_weight = hasTensor(context, kProjectionWeightsTensor);
    params.use_projection_bias = hasTensor(context, kProjectionBiasTensor);
    return params;
}

}  // namespace

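// Checks the operand counts and types. All tensor operands must share the
// input's data type (float32 or float16); requesting the optional state
// output tensors (three outputs instead of one) requires HAL version 1.3.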
bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    const uint32_t numOutputs = context->getNumOutputs();
    NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState);
    const OperandType inputType = context->getInputType(kInputTensor);
    std::vector<OperandType> inExpectedTypes;
    std::vector<OperandType> outExpectedTypes;
    if (inputType == OperandType::TENSOR_FLOAT32) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::INT32,          OperandType::FLOAT32,
                           OperandType::FLOAT32,        OperandType::BOOL,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32};
        outExpectedTypes = {OperandType::TENSOR_FLOAT32};
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::INT32,          OperandType::FLOAT16,
                           OperandType::FLOAT16,        OperandType::BOOL,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16};
        outExpectedTypes = {OperandType::TENSOR_FLOAT16};
    } else {
        NN_RET_CHECK_FAIL()
                << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_LSTM op: "
                << toString(inputType);
    }
    HalVersion minHalVersionSupported = HalVersion::V1_2;
    if (context->getNumOutputs() == kNumOutputsWithState) {
        minHalVersionSupported = HalVersion::V1_3;
        outExpectedTypes.insert(outExpectedTypes.end(), {inputType, inputType});
    }
    NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
    return validateHalVersion(context, minHalVersionSupported);
}

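// Verifies that the tensor shapes are mutually consistent with the sizes
// implied by the weights (n_cell, n_input, n_output) and sets the output
// shape(s) accordingly.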
bool prepare(IOperationExecutionContext* context) {
    // Check that none of the required inputs are omitted
    const std::vector<int> requiredInputs = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kOutputStateInTensor,
            kCellStateInTensor,
            kActivationParam,
            kCellClipParam,
            kProjClipParam,
            kTimeMajorParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    const Shape inputShape = context->getInputShape(kInputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    NN_RET_CHECK_EQ(inputRank, 3) << "Invalid input tensor rank: " << inputRank;

    const uint32_t maxTime = getSizeOfDimension(inputShape, isTimeMajor(context) ? 0 : 1);
    const uint32_t batchSize = getSizeOfDimension(inputShape, isTimeMajor(context) ? 1 : 0);
    const uint32_t inputSize = getSizeOfDimension(inputShape, inputRank - 1);

    const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToOutputShape, 1), inputSize);
    const uint32_t numCells = getSizeOfDimension(inputToOutputShape, 0);

    const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToOutputShape, 0), numCells);
    const uint32_t outputSize = getSizeOfDimension(recurrentToOutputShape, 1);

    if (hasTensor(context, kInputToInputWeightsTensor)) {
        const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 1), inputSize);
    }

    const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 1), inputSize);
    const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 1), inputSize);

    if (hasTensor(context, kRecurrentToInputWeightsTensor)) {
        const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 1), outputSize);
    }

    const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 1), outputSize);
    const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 1), outputSize);

    // Make sure the input-gate parameters are either both present (regular
    // LSTM) or both absent (CIFG-LSTM).
    const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) &&
                                       hasTensor(context, kRecurrentToInputWeightsTensor)) ||
                                      (!hasTensor(context, kInputToInputWeightsTensor) &&
                                       !hasTensor(context, kRecurrentToInputWeightsTensor));
    NN_RET_CHECK(cifgWeightsAllOrNone);

    if (hasTensor(context, kCellToInputWeightsTensor)) {
        const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToInputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToInputShape, 0), numCells);
    }

    if (hasTensor(context, kCellToForgetWeightsTensor)) {
        const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToForgetShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToForgetShape, 0), numCells);
    }

    if (hasTensor(context, kCellToOutputWeightsTensor)) {
        const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToOutputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToOutputShape, 0), numCells);
    }

    // Make sure the peephole weights are all present or all absent (with
    // CIFG, the cell-to-input weights may be omitted).
    const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor);
    const bool peepholeWeightsAllOrNone =
            ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) &&
             hasTensor(context, kCellToForgetWeightsTensor) &&
             hasTensor(context, kCellToOutputWeightsTensor)) ||
            (!hasTensor(context, kCellToInputWeightsTensor) &&
             !hasTensor(context, kCellToForgetWeightsTensor) &&
             !hasTensor(context, kCellToOutputWeightsTensor));
    NN_RET_CHECK(peepholeWeightsAllOrNone);

    if (!cifgUsed) {
        NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor));
        const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputGateBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputGateBiasShape, 0), numCells);
    } else {
        NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor))
                << "Input gate bias tensor is present when CIFG is used";
    }

    const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(forgetGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(forgetGateBiasShape, 0), numCells);
    const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellGateBiasShape, 0), numCells);
    const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputGateBiasShape, 0), numCells);

    if (hasTensor(context, kProjectionWeightsTensor)) {
        const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 0), outputSize);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 1), numCells);
    }

    if (hasTensor(context, kProjectionBiasTensor)) {
        const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionBiasShape, 0), outputSize);
    }

    const Shape outputStateShape = context->getInputShape(kOutputStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 1), outputSize);
    const Shape cellStateShape = context->getInputShape(kCellStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 1), numCells);

    if (hasTensor(context, kInputLayerNormWeightsTensor)) {
        const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kForgetLayerNormWeightsTensor)) {
        const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(forgetLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(forgetLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kCellLayerNormWeightsTensor)) {
        const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kOutputLayerNormWeightsTensor)) {
        const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(outputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(outputLayerNormShape, 0), numCells);
    }

    if (cifgUsed) {
        NN_RET_CHECK(!hasTensor(context, kInputLayerNormWeightsTensor))
                << "Input layer norm weights tensor is present when CIFG is used";
        const bool layerNormWeightsAllOrNoneCifg =
                (hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNoneCifg);
    } else {
        const bool layerNormWeightsAllOrNone =
                (hasTensor(context, kInputLayerNormWeightsTensor) &&
                 hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kInputLayerNormWeightsTensor) &&
                 !hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNone);
    }

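    // The output keeps the {time, batch} layout of the input; only the last
    // dimension changes to n_output.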
    Shape outputShape = context->getInputShape(kInputTensor);
    outputShape.dimensions[2] = outputSize;

    if (context->getNumOutputs() == kNumOutputsWithState) {
        NN_RET_CHECK(!context->isOmittedOutput(kOutputStateOutTensor));
        NN_RET_CHECK(!context->isOmittedOutput(kCellStateOutTensor));

        Shape outputStateOutTensor = context->getInputShape(kOutputStateInTensor);
        outputStateOutTensor.dimensions.resize(2);
        outputStateOutTensor.dimensions[0] = batchSize;
        outputStateOutTensor.dimensions[1] = outputSize;
        NN_RET_CHECK(context->setOutputShape(kOutputStateOutTensor, outputStateOutTensor));

        Shape cellStateOutTensor = context->getInputShape(kCellStateInTensor);
        cellStateOutTensor.dimensions.resize(2);
        cellStateOutTensor.dimensions[0] = batchSize;
        cellStateOutTensor.dimensions[1] = numCells;
        NN_RET_CHECK(context->setOutputShape(kCellStateOutTensor, cellStateOutTensor));
    }

    return context->setOutputShape(kOutputTensor, outputShape);
}

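// Runs the LSTM over the whole input sequence. If the caller did not request
// the state output tensors, temporary buffers are allocated to receive the
// final output and cell states.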
bool execute(IOperationExecutionContext* context) {
    const auto outputStateSize = getNumberOfElements(context->getInputShape(kOutputStateInTensor));
    const auto cellStateSize = getNumberOfElements(context->getInputShape(kCellStateInTensor));
    const bool use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
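    // The scratch buffer provides per-gate working space of cellStateSize
    // elements; CIFG drops the input gate, so three regions suffice instead
    // of four.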
    const auto scratchSize = use_cifg ? 3 * cellStateSize : 4 * cellStateSize;
    const bool useStateOutTensors = (context->getNumOutputs() == kNumOutputsWithState);

    const OperandType inputType = context->getInputType(kInputTensor);
    switch (inputType) {
        case OperandType::TENSOR_FLOAT32: {
            // Initialize empty vectors and resize below only if needed
            std::vector<float> outputStateOutBuffer;
            std::vector<float> cellStateOutBuffer;
            float* outputStateOut;
            float* cellStateOut;
            if (useStateOutTensors) {
                outputStateOut = context->getOutputBuffer<float>(kOutputStateOutTensor);
                cellStateOut = context->getOutputBuffer<float>(kCellStateOutTensor);
            } else {
                outputStateOutBuffer.resize(outputStateSize);
                cellStateOutBuffer.resize(cellStateSize);
                outputStateOut = outputStateOutBuffer.data();
                cellStateOut = cellStateOutBuffer.data();
            }
            std::vector<float> scratchBuffer(scratchSize);
            LSTMCell::LSTMEvalFloat32(
                    getLSTMParams<float>(context), context->getInputBuffer<float>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<float>(kInputToInputWeightsTensor),
                    context->getInputBuffer<float>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<float>(kInputToCellWeightsTensor),
                    context->getInputBuffer<float>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<float>(kCellToInputWeightsTensor),
                    context->getInputBuffer<float>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<float>(kCellToOutputWeightsTensor),
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<float>(kInputGateBiasTensor),
                    context->getInputBuffer<float>(kForgetGateBiasTensor),
                    context->getInputBuffer<float>(kCellGateBiasTensor),
                    context->getInputBuffer<float>(kOutputGateBiasTensor),
                    context->getInputBuffer<float>(kProjectionWeightsTensor),
                    context->getInputBuffer<float>(kProjectionBiasTensor),
                    context->getInputBuffer<float>(kOutputStateInTensor),
                    context->getInputBuffer<float>(kCellStateInTensor),
                    context->getInputBuffer<float>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kOutputLayerNormWeightsTensor), outputStateOut,
                    cellStateOut, context->getOutputBuffer<float>(kOutputTensor),
                    scratchBuffer.data(), isTimeMajor(context));
        } break;
        case OperandType::TENSOR_FLOAT16: {
            // Initialize empty vectors and resize below only if needed
            std::vector<_Float16> outputStateOutBuffer;
            std::vector<_Float16> cellStateOutBuffer;
            _Float16* outputStateOut;
            _Float16* cellStateOut;
            if (useStateOutTensors) {
                outputStateOut = context->getOutputBuffer<_Float16>(kOutputStateOutTensor);
                cellStateOut = context->getOutputBuffer<_Float16>(kCellStateOutTensor);
            } else {
                outputStateOutBuffer.resize(outputStateSize);
                cellStateOutBuffer.resize(cellStateSize);
                outputStateOut = outputStateOutBuffer.data();
                cellStateOut = cellStateOutBuffer.data();
            }
            std::vector<_Float16> scratchBuffer(scratchSize);
            LSTMCell::LSTMEvalFloat16(
                    getLSTMParams<_Float16>(context),
                    context->getInputBuffer<_Float16>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<_Float16>(kInputToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToOutputWeightsTensor),
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<_Float16>(kInputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kForgetGateBiasTensor),
                    context->getInputBuffer<_Float16>(kCellGateBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kProjectionWeightsTensor),
                    context->getInputBuffer<_Float16>(kProjectionBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputStateInTensor),
                    context->getInputBuffer<_Float16>(kCellStateInTensor),
                    context->getInputBuffer<_Float16>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kOutputLayerNormWeightsTensor),
                    outputStateOut, cellStateOut, context->getOutputBuffer<_Float16>(kOutputTensor),
                    scratchBuffer.data(), isTimeMajor(context));
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(inputType);
            return false;
        }
    }
    return true;
}

}  // namespace unidirectional_sequence_lstm

NN_REGISTER_OPERATION(UNIDIRECTIONAL_SEQUENCE_LSTM, "UNIDIRECTIONAL_SEQUENCE_LSTM",
                      unidirectional_sequence_lstm::validate, unidirectional_sequence_lstm::prepare,
                      unidirectional_sequence_lstm::execute, .allowOmittedOperand = true);

}  // namespace nn
}  // namespace android