1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef FRAMEWORKS_ML_NN_LSTMCELL_H 18 #define FRAMEWORKS_ML_NN_LSTMCELL_H 19 20 #include "ActivationFunctor.h" 21 #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" 22 23 #include <algorithm> 24 #include <cmath> 25 26 namespace android { 27 namespace hardware { 28 namespace neuralnetworks { 29 namespace V1_1 { 30 struct Operation; 31 } 32 } // namespace neuralnetworks 33 } // namespace hardware 34 } // namespace android 35 36 namespace android { 37 namespace nn { 38 39 struct LSTMParams { 40 TfLiteFusedActivation activation_; 41 float cell_clip_; 42 float proj_clip_; 43 }; 44 45 struct RunTimeOperandInfo; 46 struct Shape; 47 48 class LSTMCell { 49 public: 50 LSTMCell(const android::hardware::neuralnetworks::V1_1::Operation &operation, 51 std::vector<RunTimeOperandInfo> &operands); 52 53 static bool Prepare(const android::hardware::neuralnetworks::V1_1::Operation &operation, 54 std::vector<RunTimeOperandInfo> &operands, 55 Shape *scratchShape, 56 Shape *outputStateShape, 57 Shape *cellStateShape, 58 Shape *outputShape); 59 bool Eval(); 60 61 // Input Tensors of size {n_batch, n_input} 62 static constexpr int kInputTensor = 0; 63 64 // Input weight tensors of size: {n_cell, n_input} 65 static constexpr int kInputToInputWeightsTensor = 1; // Optional 66 static constexpr int kInputToForgetWeightsTensor = 2; 67 static constexpr int kInputToCellWeightsTensor = 3; 68 static constexpr int kInputToOutputWeightsTensor = 4; 69 70 // Recurrent weight tensors of size {n_cell, n_output} 71 static constexpr int kRecurrentToInputWeightsTensor = 5; // Optional 72 static constexpr int kRecurrentToForgetWeightsTensor = 6; 73 static constexpr int kRecurrentToCellWeightsTensor = 7; 74 static constexpr int kRecurrentToOutputWeightsTensor = 8; 75 76 // Peephole weights tensors of size {n_cell}, representing a diagonal matrix. 77 static constexpr int kCellToInputWeightsTensor = 9; // Optional 78 static constexpr int kCellToForgetWeightsTensor = 10; // Optional 79 static constexpr int kCellToOutputWeightsTensor = 11; // Optional 80 81 // Gates bias tensors of size {n_cell} 82 static constexpr int kInputGateBiasTensor = 12; // Optional 83 static constexpr int kForgetGateBiasTensor = 13; 84 static constexpr int kCellGateBiasTensor = 14; 85 static constexpr int kOutputGateBiasTensor = 15; 86 87 // Projection weight tensor of size {n_output, n_cell} 88 static constexpr int kProjectionWeightsTensor = 16; // Optional 89 // Projection bias tensor of size {n_output} 90 static constexpr int kProjectionBiasTensor = 17; // Optional 91 92 static constexpr int kOutputStateInTensor = 18; 93 static constexpr int kCellStateInTensor = 19; 94 95 static constexpr int kActivationParam = 20; 96 static constexpr int kCellClipParam = 21; 97 static constexpr int kProjClipParam = 22; 98 99 // Output tensors. 100 static constexpr int kScratchBufferTensor = 0; 101 static constexpr int kOutputStateOutTensor = 1; 102 static constexpr int kCellStateOutTensor = 2; 103 static constexpr int kOutputTensor = 3; 104 105 private: 106 static bool CheckInputTensorDimensions( 107 const android::hardware::neuralnetworks::V1_1::Operation &operation, 108 std::vector<RunTimeOperandInfo> &operands, uint32_t n_input, 109 uint32_t n_output, uint32_t n_cell); 110 LSTMParams params_; 111 112 const RunTimeOperandInfo *input_; 113 114 const RunTimeOperandInfo *input_to_input_weights_; 115 const RunTimeOperandInfo *input_to_forget_weights_; 116 const RunTimeOperandInfo *input_to_cell_weights_; 117 const RunTimeOperandInfo *input_to_output_weights_; 118 119 const RunTimeOperandInfo *recurrent_to_input_weights_; 120 const RunTimeOperandInfo *recurrent_to_forget_weights_; 121 const RunTimeOperandInfo *recurrent_to_cell_weights_; 122 const RunTimeOperandInfo *recurrent_to_output_weights_; 123 124 const RunTimeOperandInfo *cell_to_input_weights_; 125 const RunTimeOperandInfo *cell_to_forget_weights_; 126 const RunTimeOperandInfo *cell_to_output_weights_; 127 128 const RunTimeOperandInfo *input_gate_bias_; 129 const RunTimeOperandInfo *forget_gate_bias_; 130 const RunTimeOperandInfo *cell_bias_; 131 const RunTimeOperandInfo *output_gate_bias_; 132 133 const RunTimeOperandInfo *projection_weights_; 134 const RunTimeOperandInfo *projection_bias_; 135 136 const RunTimeOperandInfo *output_state_in_; 137 const RunTimeOperandInfo *cell_state_in_; 138 139 RunTimeOperandInfo *output_state_out_; 140 RunTimeOperandInfo *cell_state_out_; 141 RunTimeOperandInfo *output_; 142 143 RunTimeOperandInfo *scratch_buffer_; 144 }; 145 146 } // namespace nn 147 } // namespace android 148 149 #endif // FRAMEWORKS_ML_NN_LSTMCELL_H 150