Lines Matching refs:context
40 inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) { in hasTensor() argument
41 return context->getInputBuffer(tensor) != nullptr; in hasTensor()
44 inline bool isTimeMajor(IOperationExecutionContext* context) { in isTimeMajor() argument
45 return context->getInputValue<bool>(kTimeMajorParam); in isTimeMajor()
49 inline LSTMParams getLSTMParams(IOperationExecutionContext* context) { in getLSTMParams() argument
52 static_cast<ActivationFn>(context->getInputValue<int32_t>(kActivationParam)); in getLSTMParams()
53 params.cell_clip = static_cast<float>(context->getInputValue<T>(kCellClipParam)); in getLSTMParams()
54 params.proj_clip = static_cast<float>(context->getInputValue<T>(kProjClipParam)); in getLSTMParams()
55 params.use_cifg = !hasTensor(context, kInputToInputWeightsTensor); in getLSTMParams()
56 params.use_peephole = hasTensor(context, kCellToOutputWeightsTensor); in getLSTMParams()
57 params.use_layer_norm = hasTensor(context, kOutputLayerNormWeightsTensor); in getLSTMParams()
58 params.use_projection_weight = hasTensor(context, kProjectionWeightsTensor); in getLSTMParams()
59 params.use_projection_bias = hasTensor(context, kProjectionBiasTensor); in getLSTMParams()
65 bool prepare(IOperationExecutionContext* context) { in prepare() argument
86 NN_RET_CHECK(!context->isOmittedInput(requiredInput)) in prepare()
90 const Shape inputShape = context->getInputShape(kInputTensor); in prepare()
95 getSizeOfDimension(inputShape, isTimeMajor(context) ? 0 : 1); in prepare()
96 const uint32_t batchSize = getSizeOfDimension(inputShape, isTimeMajor(context) ? 1 : 0); in prepare()
99 const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor); in prepare()
104 const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor); in prepare()
109 if (hasTensor(context, kInputToInputWeightsTensor)) { in prepare()
110 const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor); in prepare()
116 const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor); in prepare()
120 const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor); in prepare()
125 if (hasTensor(context, kRecurrentToInputWeightsTensor)) { in prepare()
126 const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor); in prepare()
132 const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor); in prepare()
136 const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor); in prepare()
143 const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) && in prepare()
144 hasTensor(context, kRecurrentToInputWeightsTensor)) || in prepare()
145 (!hasTensor(context, kInputToInputWeightsTensor) && in prepare()
146 !hasTensor(context, kRecurrentToInputWeightsTensor)); in prepare()
149 if (hasTensor(context, kCellToInputWeightsTensor)) { in prepare()
150 const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor); in prepare()
155 if (hasTensor(context, kCellToForgetWeightsTensor)) { in prepare()
156 const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor); in prepare()
161 if (hasTensor(context, kCellToOutputWeightsTensor)) { in prepare()
162 const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor); in prepare()
168 const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor); in prepare()
170 ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) && in prepare()
171 hasTensor(context, kCellToForgetWeightsTensor) && in prepare()
172 hasTensor(context, kCellToOutputWeightsTensor)) || in prepare()
173 (!hasTensor(context, kCellToInputWeightsTensor) && in prepare()
174 !hasTensor(context, kCellToForgetWeightsTensor) && in prepare()
175 !hasTensor(context, kCellToOutputWeightsTensor)); in prepare()
179 NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor)); in prepare()
180 const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor); in prepare()
184 NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor)) in prepare()
188 const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor); in prepare()
191 const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor); in prepare()
194 const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor); in prepare()
198 if (hasTensor(context, kProjectionWeightsTensor)) { in prepare()
199 const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor); in prepare()
205 if (hasTensor(context, kProjectionBiasTensor)) { in prepare()
206 const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor); in prepare()
211 const Shape outputStateShape = context->getInputShape(kOutputStateInTensor); in prepare()
215 const Shape cellStateShape = context->getInputShape(kCellStateInTensor); in prepare()
220 if (hasTensor(context, kInputLayerNormWeightsTensor)) { in prepare()
221 const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormWeightsTensor); in prepare()
226 if (hasTensor(context, kForgetLayerNormWeightsTensor)) { in prepare()
227 const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormWeightsTensor); in prepare()
232 if (hasTensor(context, kCellLayerNormWeightsTensor)) { in prepare()
233 const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormWeightsTensor); in prepare()
238 if (hasTensor(context, kOutputLayerNormWeightsTensor)) { in prepare()
239 const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormWeightsTensor); in prepare()
245 NN_RET_CHECK(!hasTensor(context, kInputLayerNormWeightsTensor)) in prepare()
248 (hasTensor(context, kForgetLayerNormWeightsTensor) && in prepare()
249 hasTensor(context, kCellLayerNormWeightsTensor) && in prepare()
250 hasTensor(context, kOutputLayerNormWeightsTensor)) || in prepare()
251 (!hasTensor(context, kForgetLayerNormWeightsTensor) && in prepare()
252 !hasTensor(context, kCellLayerNormWeightsTensor) && in prepare()
253 !hasTensor(context, kOutputLayerNormWeightsTensor)); in prepare()
257 (hasTensor(context, kInputLayerNormWeightsTensor) && in prepare()
258 hasTensor(context, kForgetLayerNormWeightsTensor) && in prepare()
259 hasTensor(context, kCellLayerNormWeightsTensor) && in prepare()
260 hasTensor(context, kOutputLayerNormWeightsTensor)) || in prepare()
261 (!hasTensor(context, kInputLayerNormWeightsTensor) && in prepare()
262 !hasTensor(context, kForgetLayerNormWeightsTensor) && in prepare()
263 !hasTensor(context, kCellLayerNormWeightsTensor) && in prepare()
264 !hasTensor(context, kOutputLayerNormWeightsTensor)); in prepare()
268 Shape outputShape = context->getInputShape(kInputTensor); in prepare()
271 if (context->getNumOutputs() == kNumOutputsWithState) { in prepare()
272 NN_RET_CHECK(!context->isOmittedOutput(kOutputStateOutTensor)); in prepare()
273 NN_RET_CHECK(!context->isOmittedOutput(kCellStateOutTensor)); in prepare()
275 Shape outputStateOutTensor = context->getInputShape(kOutputStateInTensor); in prepare()
279 NN_RET_CHECK(context->setOutputShape(kOutputStateOutTensor, outputStateOutTensor)); in prepare()
281 Shape cellStateOutTensor = context->getInputShape(kCellStateInTensor); in prepare()
285 NN_RET_CHECK(context->setOutputShape(kCellStateOutTensor, cellStateOutTensor)); in prepare()
288 return context->setOutputShape(kOutputTensor, outputShape); in prepare()
291 bool execute(IOperationExecutionContext* context) { in execute() argument
292 const auto outputStateSize = getNumberOfElements(context->getInputShape(kOutputStateInTensor)); in execute()
293 const auto cellStateSize = getNumberOfElements(context->getInputShape(kCellStateInTensor)); in execute()
294 const bool use_cifg = !hasTensor(context, kInputToInputWeightsTensor); in execute()
296 const bool useStateOutTensors = (context->getNumOutputs() == kNumOutputsWithState); in execute()
298 const OperandType inputType = context->getInputType(kInputTensor); in execute()
307 outputStateOut = context->getOutputBuffer<float>(kOutputStateOutTensor); in execute()
308 cellStateOut = context->getOutputBuffer<float>(kCellStateOutTensor); in execute()
317 getLSTMParams<float>(context), context->getInputBuffer<float>(kInputTensor), in execute()
318 context->getInputShape(kInputTensor), in execute()
319 context->getInputBuffer<float>(kInputToInputWeightsTensor), in execute()
320 context->getInputBuffer<float>(kInputToForgetWeightsTensor), in execute()
321 context->getInputBuffer<float>(kInputToCellWeightsTensor), in execute()
322 context->getInputBuffer<float>(kInputToOutputWeightsTensor), in execute()
323 context->getInputShape(kInputToOutputWeightsTensor), in execute()
324 context->getInputBuffer<float>(kRecurrentToInputWeightsTensor), in execute()
325 context->getInputBuffer<float>(kRecurrentToForgetWeightsTensor), in execute()
326 context->getInputBuffer<float>(kRecurrentToCellWeightsTensor), in execute()
327 context->getInputBuffer<float>(kRecurrentToOutputWeightsTensor), in execute()
328 context->getInputShape(kRecurrentToOutputWeightsTensor), in execute()
329 context->getInputBuffer<float>(kCellToInputWeightsTensor), in execute()
330 context->getInputBuffer<float>(kCellToForgetWeightsTensor), in execute()
331 context->getInputBuffer<float>(kCellToOutputWeightsTensor), in execute()
337 context->getInputBuffer<float>(kInputGateBiasTensor), in execute()
338 context->getInputBuffer<float>(kForgetGateBiasTensor), in execute()
339 context->getInputBuffer<float>(kCellGateBiasTensor), in execute()
340 context->getInputBuffer<float>(kOutputGateBiasTensor), in execute()
341 context->getInputBuffer<float>(kProjectionWeightsTensor), in execute()
342 context->getInputBuffer<float>(kProjectionBiasTensor), in execute()
343 context->getInputBuffer<float>(kOutputStateInTensor), in execute()
344 context->getInputBuffer<float>(kCellStateInTensor), in execute()
345 context->getInputBuffer<float>(kInputLayerNormWeightsTensor), in execute()
346 context->getInputBuffer<float>(kForgetLayerNormWeightsTensor), in execute()
347 context->getInputBuffer<float>(kCellLayerNormWeightsTensor), in execute()
348 context->getInputBuffer<float>(kOutputLayerNormWeightsTensor), outputStateOut, in execute()
349 cellStateOut, context->getOutputBuffer<float>(kOutputTensor), in execute()
350 scratchBuffer.data(), isTimeMajor(context)); in execute()
359 outputStateOut = context->getOutputBuffer<_Float16>(kOutputStateOutTensor); in execute()
360 cellStateOut = context->getOutputBuffer<_Float16>(kCellStateOutTensor); in execute()
369 getLSTMParams<_Float16>(context), in execute()
370 context->getInputBuffer<_Float16>(kInputTensor), in execute()
371 context->getInputShape(kInputTensor), in execute()
372 context->getInputBuffer<_Float16>(kInputToInputWeightsTensor), in execute()
373 context->getInputBuffer<_Float16>(kInputToForgetWeightsTensor), in execute()
374 context->getInputBuffer<_Float16>(kInputToCellWeightsTensor), in execute()
375 context->getInputBuffer<_Float16>(kInputToOutputWeightsTensor), in execute()
376 context->getInputShape(kInputToOutputWeightsTensor), in execute()
377 context->getInputBuffer<_Float16>(kRecurrentToInputWeightsTensor), in execute()
378 context->getInputBuffer<_Float16>(kRecurrentToForgetWeightsTensor), in execute()
379 context->getInputBuffer<_Float16>(kRecurrentToCellWeightsTensor), in execute()
380 context->getInputBuffer<_Float16>(kRecurrentToOutputWeightsTensor), in execute()
381 context->getInputShape(kRecurrentToOutputWeightsTensor), in execute()
382 context->getInputBuffer<_Float16>(kCellToInputWeightsTensor), in execute()
383 context->getInputBuffer<_Float16>(kCellToForgetWeightsTensor), in execute()
384 context->getInputBuffer<_Float16>(kCellToOutputWeightsTensor), in execute()
390 context->getInputBuffer<_Float16>(kInputGateBiasTensor), in execute()
391 context->getInputBuffer<_Float16>(kForgetGateBiasTensor), in execute()
392 context->getInputBuffer<_Float16>(kCellGateBiasTensor), in execute()
393 context->getInputBuffer<_Float16>(kOutputGateBiasTensor), in execute()
394 context->getInputBuffer<_Float16>(kProjectionWeightsTensor), in execute()
395 context->getInputBuffer<_Float16>(kProjectionBiasTensor), in execute()
396 context->getInputBuffer<_Float16>(kOutputStateInTensor), in execute()
397 context->getInputBuffer<_Float16>(kCellStateInTensor), in execute()
398 context->getInputBuffer<_Float16>(kInputLayerNormWeightsTensor), in execute()
399 context->getInputBuffer<_Float16>(kForgetLayerNormWeightsTensor), in execute()
400 context->getInputBuffer<_Float16>(kCellLayerNormWeightsTensor), in execute()
401 context->getInputBuffer<_Float16>(kOutputLayerNormWeightsTensor), in execute()
402 outputStateOut, cellStateOut, context->getOutputBuffer<_Float16>(kOutputTensor), in execute()
403 scratchBuffer.data(), isTimeMajor(context)); in execute()