Searched refs:kProjectionWeightsTensor (Results 1 – 6 of 6) sorted by relevance
/frameworks/ml/nn/common/operations/ |
D | UnidirectionalSequenceLSTM.cpp | 61 constexpr uint32_t kProjectionWeightsTensor = 16; // Optional variable 111 params.use_projection_weight = hasTensor(context, kProjectionWeightsTensor); in getLSTMParams() 304 if (hasTensor(context, kProjectionWeightsTensor)) { in prepare() 305 const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor); in prepare() 447 context->getInputBuffer<float>(kProjectionWeightsTensor), in execute() 500 context->getInputBuffer<_Float16>(kProjectionWeightsTensor), in execute()
|
D | QLSTM.cpp | 60 constexpr uint32_t kProjectionWeightsTensor = 16; variable 284 if (hasTensor(context, kProjectionWeightsTensor)) { in prepare() 285 const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor); in prepare() 386 const Shape projectionWeightsShape = context->getInputShape(kProjectionWeightsTensor); in execute() 451 reinterpret_cast<const int8_t*>(context->getInputBuffer(kProjectionWeightsTensor)); in execute()
|
D | LSTM.h | 84 static constexpr int kProjectionWeightsTensor = 16; // Optional variable
|
D | LayerNormLSTMTest.cpp | 247 execution.setInput(LSTMCell::kProjectionWeightsTensor, nullptr, 0); in Invoke()
|
D | LSTMTest.cpp | 236 execution.setInput(LSTMCell::kProjectionWeightsTensor, nullptr, 0); in Invoke()
|
D | LSTM.cpp | 80 projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor); // optional in LSTMCell()
|