Home
last modified time | relevance | path

Searched refs:kProjectionWeightsTensor (Results 1 – 6 of 6) sorted by relevance

/frameworks/ml/nn/common/operations/
UnidirectionalSequenceLSTM.cpp 61 constexpr uint32_t kProjectionWeightsTensor = 16; // Optional variable
111 params.use_projection_weight = hasTensor(context, kProjectionWeightsTensor); in getLSTMParams()
304 if (hasTensor(context, kProjectionWeightsTensor)) { in prepare()
305 const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor); in prepare()
447 context->getInputBuffer<float>(kProjectionWeightsTensor), in execute()
500 context->getInputBuffer<_Float16>(kProjectionWeightsTensor), in execute()
QLSTM.cpp 60 constexpr uint32_t kProjectionWeightsTensor = 16; variable
284 if (hasTensor(context, kProjectionWeightsTensor)) { in prepare()
285 const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor); in prepare()
386 const Shape projectionWeightsShape = context->getInputShape(kProjectionWeightsTensor); in execute()
451 reinterpret_cast<const int8_t*>(context->getInputBuffer(kProjectionWeightsTensor)); in execute()
LSTM.h 84 static constexpr int kProjectionWeightsTensor = 16; // Optional variable
LayerNormLSTMTest.cpp 247 execution.setInput(LSTMCell::kProjectionWeightsTensor, nullptr, 0); in Invoke()
LSTMTest.cpp 236 execution.setInput(LSTMCell::kProjectionWeightsTensor, nullptr, 0); in Invoke()
LSTM.cpp 80 projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor); // optional in LSTMCell()