1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FRAMEWORKS_ML_NN_LSTMCELL_H
18 #define FRAMEWORKS_ML_NN_LSTMCELL_H
19 
#include "ActivationFunctor.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>
25 
// The HAL Operation type is only named here through const references, so a
// forward declaration is enough and avoids pulling in the HAL headers.
namespace android {
namespace hardware {
namespace neuralnetworks {
namespace V1_1 {
struct Operation;
}  // namespace V1_1
}  // namespace neuralnetworks
}  // namespace hardware
}  // namespace android
35 
36 namespace android {
37 namespace nn {
38 
39 struct LSTMParams {
40   TfLiteFusedActivation activation_;
41   float cell_clip_;
42   float proj_clip_;
43 };
44 
// Project types defined elsewhere. This header refers to them only through
// pointers, references and `std::vector<RunTimeOperandInfo> &`, so forward
// declarations suffice.
struct Shape;
struct RunTimeOperandInfo;
47 
48 class LSTMCell {
49  public:
50   LSTMCell(const android::hardware::neuralnetworks::V1_1::Operation &operation,
51            std::vector<RunTimeOperandInfo> &operands);
52 
53   static bool Prepare(const android::hardware::neuralnetworks::V1_1::Operation &operation,
54                       std::vector<RunTimeOperandInfo> &operands,
55                       Shape *scratchShape,
56                       Shape *outputStateShape,
57                       Shape *cellStateShape,
58                       Shape *outputShape);
59   bool Eval();
60 
61   // Input Tensors of size {n_batch, n_input}
62   static constexpr int kInputTensor = 0;
63 
64   // Input weight tensors of size: {n_cell, n_input}
65   static constexpr int kInputToInputWeightsTensor = 1;  // Optional
66   static constexpr int kInputToForgetWeightsTensor = 2;
67   static constexpr int kInputToCellWeightsTensor = 3;
68   static constexpr int kInputToOutputWeightsTensor = 4;
69 
70   // Recurrent weight tensors of size {n_cell, n_output}
71   static constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
72   static constexpr int kRecurrentToForgetWeightsTensor = 6;
73   static constexpr int kRecurrentToCellWeightsTensor = 7;
74   static constexpr int kRecurrentToOutputWeightsTensor = 8;
75 
76   // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
77   static constexpr int kCellToInputWeightsTensor = 9;    // Optional
78   static constexpr int kCellToForgetWeightsTensor = 10;  // Optional
79   static constexpr int kCellToOutputWeightsTensor = 11;  // Optional
80 
81   // Gates bias tensors of size {n_cell}
82   static constexpr int kInputGateBiasTensor = 12;  // Optional
83   static constexpr int kForgetGateBiasTensor = 13;
84   static constexpr int kCellGateBiasTensor = 14;
85   static constexpr int kOutputGateBiasTensor = 15;
86 
87   // Projection weight tensor of size {n_output, n_cell}
88   static constexpr int kProjectionWeightsTensor = 16;  // Optional
89   // Projection bias tensor of size {n_output}
90   static constexpr int kProjectionBiasTensor = 17;  // Optional
91 
92   static constexpr int kOutputStateInTensor = 18;
93   static constexpr int kCellStateInTensor = 19;
94 
95   static constexpr int kActivationParam = 20;
96   static constexpr int kCellClipParam = 21;
97   static constexpr int kProjClipParam = 22;
98 
99   // Output tensors.
100   static constexpr int kScratchBufferTensor = 0;
101   static constexpr int kOutputStateOutTensor = 1;
102   static constexpr int kCellStateOutTensor = 2;
103   static constexpr int kOutputTensor = 3;
104 
105  private:
106   static bool CheckInputTensorDimensions(
107       const android::hardware::neuralnetworks::V1_1::Operation &operation,
108       std::vector<RunTimeOperandInfo> &operands, uint32_t n_input,
109       uint32_t n_output, uint32_t n_cell);
110   LSTMParams params_;
111 
112   const RunTimeOperandInfo *input_;
113 
114   const RunTimeOperandInfo *input_to_input_weights_;
115   const RunTimeOperandInfo *input_to_forget_weights_;
116   const RunTimeOperandInfo *input_to_cell_weights_;
117   const RunTimeOperandInfo *input_to_output_weights_;
118 
119   const RunTimeOperandInfo *recurrent_to_input_weights_;
120   const RunTimeOperandInfo *recurrent_to_forget_weights_;
121   const RunTimeOperandInfo *recurrent_to_cell_weights_;
122   const RunTimeOperandInfo *recurrent_to_output_weights_;
123 
124   const RunTimeOperandInfo *cell_to_input_weights_;
125   const RunTimeOperandInfo *cell_to_forget_weights_;
126   const RunTimeOperandInfo *cell_to_output_weights_;
127 
128   const RunTimeOperandInfo *input_gate_bias_;
129   const RunTimeOperandInfo *forget_gate_bias_;
130   const RunTimeOperandInfo *cell_bias_;
131   const RunTimeOperandInfo *output_gate_bias_;
132 
133   const RunTimeOperandInfo *projection_weights_;
134   const RunTimeOperandInfo *projection_bias_;
135 
136   const RunTimeOperandInfo *output_state_in_;
137   const RunTimeOperandInfo *cell_state_in_;
138 
139   RunTimeOperandInfo *output_state_out_;
140   RunTimeOperandInfo *cell_state_out_;
141   RunTimeOperandInfo *output_;
142 
143   RunTimeOperandInfo *scratch_buffer_;
144 };
145 
146 }  // namespace nn
147 }  // namespace android
148 
149 #endif  // FRAMEWORKS_ML_NN_LSTMCELL_H
150