1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FRAMEWORKS_ML_NN_LSTMCELL_H
18 #define FRAMEWORKS_ML_NN_LSTMCELL_H
19 
#include "ActivationFunctor.h"
#include "HalOperation.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>
26 
27 namespace android {
28 namespace nn {
29 
30 struct LSTMParams {
31     TfLiteFusedActivation activation;
32     float cell_clip;
33     float proj_clip;
34     bool use_cifg;
35     bool use_peephole;
36     bool use_layer_norm;
37     bool use_projection_weight;
38     bool use_projection_bias;
39     bool merge_outputs;
40     bool time_major;
41 };
42 
// Forward declarations: within this header these types are only named through
// pointers and references (and as the element type of a std::vector taken by
// reference), so their full definitions are not required here.
struct RunTimeOperandInfo;
struct Shape;
45 
46 class LSTMCell {
47    public:
48     LSTMCell(const Operation& operation, std::vector<RunTimeOperandInfo>& operands);
49 
50     bool Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands,
51                  Shape* scratchShape, Shape* outputStateShape, Shape* cellStateShape,
52                  Shape* outputShape);
53     bool Eval();
54 
55     // Input Tensors of size {n_batch, n_input}
56     static constexpr int kInputTensor = 0;
57 
58     // Input weight tensors of size: {n_cell, n_input}
59     static constexpr int kInputToInputWeightsTensor = 1;  // Optional
60     static constexpr int kInputToForgetWeightsTensor = 2;
61     static constexpr int kInputToCellWeightsTensor = 3;
62     static constexpr int kInputToOutputWeightsTensor = 4;
63 
64     // Recurrent weight tensors of size {n_cell, n_output}
65     static constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
66     static constexpr int kRecurrentToForgetWeightsTensor = 6;
67     static constexpr int kRecurrentToCellWeightsTensor = 7;
68     static constexpr int kRecurrentToOutputWeightsTensor = 8;
69 
70     // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
71     static constexpr int kCellToInputWeightsTensor = 9;    // Optional
72     static constexpr int kCellToForgetWeightsTensor = 10;  // Optional
73     static constexpr int kCellToOutputWeightsTensor = 11;  // Optional
74 
75     // Gates bias tensors of size {n_cell}
76     static constexpr int kInputGateBiasTensor = 12;  // Optional
77     static constexpr int kForgetGateBiasTensor = 13;
78     static constexpr int kCellGateBiasTensor = 14;
79     static constexpr int kOutputGateBiasTensor = 15;
80 
81     // Projection weight tensor of size {n_output, n_cell}
82     static constexpr int kProjectionWeightsTensor = 16;  // Optional
83     // Projection bias tensor of size {n_output}
84     static constexpr int kProjectionBiasTensor = 17;  // Optional
85 
86     static constexpr int kOutputStateInTensor = 18;
87     static constexpr int kCellStateInTensor = 19;
88 
89     static constexpr int kActivationParam = 20;
90     static constexpr int kCellClipParam = 21;
91     static constexpr int kProjClipParam = 22;
92 
93     // Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
94     static constexpr int kInputLayerNormWeightsTensor = 23;
95     static constexpr int kForgetLayerNormWeightsTensor = 24;
96     static constexpr int kCellLayerNormWeightsTensor = 25;
97     static constexpr int kOutputLayerNormWeightsTensor = 26;
98 
99     // Output tensors.
100     static constexpr int kScratchBufferTensor = 0;
101     static constexpr int kOutputStateOutTensor = 1;
102     static constexpr int kCellStateOutTensor = 2;
103     static constexpr int kOutputTensor = 3;
104 
105     static constexpr float kLayerNormEpsilon = 1e-8;
106 
107     static bool LSTMEvalFloat32(
108             const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
109             const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
110             const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
111             const Shape& input_to_output_weights_shape,
112             const float* recurrent_to_input_weights_buffer,
113             const float* recurrent_to_forget_weights_buffer,
114             const float* recurrent_to_cell_weights_buffer,
115             const float* recurrent_to_output_weights_buffer,
116             const Shape& recurrent_to_output_weights_shape,
117             const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer,
118             const float* cell_to_output_weights_buffer, const float* aux_input_buffer,
119             const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights,
120             const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights,
121             const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
122             const float* cell_bias_buffer, const float* output_gate_bias_buffer,
123             const float* projection_weights_buffer, const float* projection_bias_buffer,
124             const float* output_state_in_buffer, const float* cell_state_in_buffer,
125             const float* input_layer_norm_weights_buffer,
126             const float* forget_layer_norm_weights_buffer,
127             const float* cell_layer_norm_weights_buffer,
128             const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
129             float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer,
130             bool timeMajor = true, bool forwardSequence = true);
131 
132     static bool LSTMEvalFloat16(
133             const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape,
134             const _Float16* input_to_input_weights_buffer,
135             const _Float16* input_to_forget_weights_buffer,
136             const _Float16* input_to_cell_weights_buffer,
137             const _Float16* input_to_output_weights_buffer,
138             const Shape& input_to_output_weights_shape,
139             const _Float16* recurrent_to_input_weights_buffer,
140             const _Float16* recurrent_to_forget_weights_buffer,
141             const _Float16* recurrent_to_cell_weights_buffer,
142             const _Float16* recurrent_to_output_weights_buffer,
143             const Shape& recurrent_to_output_weights_shape,
144             const _Float16* cell_to_input_weights_buffer,
145             const _Float16* cell_to_forget_weights_buffer,
146             const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer,
147             const _Float16* aux_input_to_input_weights, const _Float16* aux_input_to_forget_weights,
148             const _Float16* aux_input_to_cell_weights, const _Float16* aux_input_to_output_weights,
149             const _Float16* input_gate_bias_buffer, const _Float16* forget_gate_bias_buffer,
150             const _Float16* cell_bias_buffer, const _Float16* output_gate_bias_buffer,
151             const _Float16* projection_weights_buffer, const _Float16* projection_bias_buffer,
152             const _Float16* output_state_in_buffer, const _Float16* cell_state_in_buffer,
153             const _Float16* input_layer_norm_weights_buffer,
154             const _Float16* forget_layer_norm_weights_buffer,
155             const _Float16* cell_layer_norm_weights_buffer,
156             const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
157             _Float16* cell_state_out_buffer, _Float16* output_buffer,
158             _Float16* scratch_buffer_buffer, bool timeMajor = true, bool forwardSequence = true);
159 
160     static bool LSTMStep(
161             const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
162             const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
163             const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
164             const Shape& input_to_output_weights_shape,
165             const float* recurrent_to_input_weights_buffer,
166             const float* recurrent_to_forget_weights_buffer,
167             const float* recurrent_to_cell_weights_buffer,
168             const float* recurrent_to_output_weights_buffer,
169             const Shape& recurrent_to_output_weights_shape,
170             const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer,
171             const float* cell_to_output_weights_buffer, const float* aux_input_buffer,
172             const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights,
173             const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights,
174             const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
175             const float* cell_bias_buffer, const float* output_gate_bias_buffer,
176             const float* projection_weights_buffer, const float* projection_bias_buffer,
177             const float* output_state_in_buffer, const float* cell_state_in_buffer,
178             const float* input_layer_norm_weights_buffer,
179             const float* forget_layer_norm_weights_buffer,
180             const float* cell_layer_norm_weights_buffer,
181             const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
182             float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer);
183 
184     static bool CheckInputTensorDimensions(
185             const RunTimeOperandInfo* input_, const RunTimeOperandInfo* input_to_input_weights,
186             const RunTimeOperandInfo* input_to_forget_weights,
187             const RunTimeOperandInfo* input_to_cell_weights,
188             const RunTimeOperandInfo* input_to_output_weights,
189             const RunTimeOperandInfo* recurrent_to_input_weights,
190             const RunTimeOperandInfo* recurrent_to_forget_weights,
191             const RunTimeOperandInfo* recurrent_to_cell_weights,
192             const RunTimeOperandInfo* recurrent_to_output_weights,
193             const RunTimeOperandInfo* cell_to_input_weights,
194             const RunTimeOperandInfo* cell_to_forget_weights,
195             const RunTimeOperandInfo* cell_to_output_weights,
196             const RunTimeOperandInfo* input_gate_bias, const RunTimeOperandInfo* forget_gate_bias,
197             const RunTimeOperandInfo* cell_bias, const RunTimeOperandInfo* output_gate_bias,
198             const RunTimeOperandInfo* projection_weights, const RunTimeOperandInfo* projection_bias,
199             const RunTimeOperandInfo* input_layer_norm_weights,
200             const RunTimeOperandInfo* forget_layer_norm_weights,
201             const RunTimeOperandInfo* cell_layer_norm_weights,
202             const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input,
203             uint32_t n_output, uint32_t n_cell, LSTMParams* params);
204 
205    private:
206     LSTMParams params_;
207     const RunTimeOperandInfo* input_;
208 
209     const RunTimeOperandInfo* input_to_input_weights_;
210     const RunTimeOperandInfo* input_to_forget_weights_;
211     const RunTimeOperandInfo* input_to_cell_weights_;
212     const RunTimeOperandInfo* input_to_output_weights_;
213 
214     const RunTimeOperandInfo* recurrent_to_input_weights_;
215     const RunTimeOperandInfo* recurrent_to_forget_weights_;
216     const RunTimeOperandInfo* recurrent_to_cell_weights_;
217     const RunTimeOperandInfo* recurrent_to_output_weights_;
218 
219     const RunTimeOperandInfo* cell_to_input_weights_;
220     const RunTimeOperandInfo* cell_to_forget_weights_;
221     const RunTimeOperandInfo* cell_to_output_weights_;
222 
223     const RunTimeOperandInfo* input_gate_bias_;
224     const RunTimeOperandInfo* forget_gate_bias_;
225     const RunTimeOperandInfo* cell_bias_;
226     const RunTimeOperandInfo* output_gate_bias_;
227 
228     const RunTimeOperandInfo* projection_weights_;
229     const RunTimeOperandInfo* projection_bias_;
230 
231     const RunTimeOperandInfo* output_state_in_;
232     const RunTimeOperandInfo* cell_state_in_;
233 
234     const RunTimeOperandInfo* input_layer_norm_weights_;
235     const RunTimeOperandInfo* forget_layer_norm_weights_;
236     const RunTimeOperandInfo* cell_layer_norm_weights_;
237     const RunTimeOperandInfo* output_layer_norm_weights_;
238 
239     RunTimeOperandInfo* output_state_out_;
240     RunTimeOperandInfo* cell_state_out_;
241     RunTimeOperandInfo* output_;
242 
243     RunTimeOperandInfo* scratch_buffer_;
244 };
245 
246 }  // namespace nn
247 }  // namespace android
248 
249 #endif  // FRAMEWORKS_ML_NN_LSTMCELL_H
250