1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_TYPES_OPERATIONS_BIDIRECTIONAL_SEQUENCE_LSTM_H
18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_TYPES_OPERATIONS_BIDIRECTIONAL_SEQUENCE_LSTM_H
19 
20 #include <algorithm>
21 #include <cmath>
22 #include <vector>
23 
24 #include "ActivationFunctor.h"
25 #include "LSTM.h"
26 #include "OperationsValidationUtils.h"
27 
28 namespace android {
29 namespace nn {
30 
31 struct RunTimeOperandInfo;
32 
33 class BidirectionalSequenceLSTM {
34    public:
35     BidirectionalSequenceLSTM(const Operation& operation, RunTimeOperandInfo* operands);
36 
37     bool Prepare(const Operation& operation, RunTimeOperandInfo* operands, Shape* fwOutputShape,
38                  Shape* bwOutputShape, Shape* fwOutputActivationState, Shape* fwOutputCellState,
39                  Shape* bwOutputActivationState, Shape* bwOutputCellState);
40     bool Eval();
41 
42     // Input Tensors of size {max_time, n_batch, n_input}
43     static constexpr int kInputTensor = 0;
44 
45     // Forward LSTM cell tensors.
46     // Input weight tensors of size: {n_cell, n_input}
47     static constexpr int kFwInputToInputWeightsTensor = 1;  // Optional
48     static constexpr int kFwInputToForgetWeightsTensor = 2;
49     static constexpr int kFwInputToCellWeightsTensor = 3;
50     static constexpr int kFwInputToOutputWeightsTensor = 4;
51 
52     // Recurrent weight tensors of size {n_cell, n_output}
53     static constexpr int kFwRecurrentToInputWeightsTensor = 5;  // Optional
54     static constexpr int kFwRecurrentToForgetWeightsTensor = 6;
55     static constexpr int kFwRecurrentToCellWeightsTensor = 7;
56     static constexpr int kFwRecurrentToOutputWeightsTensor = 8;
57 
58     // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
59     static constexpr int kFwCellToInputWeightsTensor = 9;    // Optional
60     static constexpr int kFwCellToForgetWeightsTensor = 10;  // Optional
61     static constexpr int kFwCellToOutputWeightsTensor = 11;  // Optional
62 
63     // Gates bias tensors of size {n_cell}
64     static constexpr int kFwInputGateBiasTensor = 12;  // Optional
65     static constexpr int kFwForgetGateBiasTensor = 13;
66     static constexpr int kFwCellGateBiasTensor = 14;
67     static constexpr int kFwOutputGateBiasTensor = 15;
68 
69     // Projection weight tensor of size {n_output, n_cell}
70     static constexpr int kFwProjectionWeightsTensor = 16;  // Optional
71     // Projection bias tensor of size {n_output}
72     static constexpr int kFwProjectionBiasTensor = 17;  // Optional
73 
74     // Backward LSTM cell tensors.
75     // Input weight tensors of size: {n_cell, n_input}
76     static constexpr int kBwInputToInputWeightsTensor = 18;  // Optional
77     static constexpr int kBwInputToForgetWeightsTensor = 19;
78     static constexpr int kBwInputToCellWeightsTensor = 20;
79     static constexpr int kBwInputToOutputWeightsTensor = 21;
80 
81     // Recurrent weight tensors of size {n_cell, n_output}
82     static constexpr int kBwRecurrentToInputWeightsTensor = 22;  // Optional
83     static constexpr int kBwRecurrentToForgetWeightsTensor = 23;
84     static constexpr int kBwRecurrentToCellWeightsTensor = 24;
85     static constexpr int kBwRecurrentToOutputWeightsTensor = 25;
86 
87     // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
88     static constexpr int kBwCellToInputWeightsTensor = 26;   // Optional
89     static constexpr int kBwCellToForgetWeightsTensor = 27;  // Optional
90     static constexpr int kBwCellToOutputWeightsTensor = 28;  // Optional
91 
92     // Gates bias tensors of size {n_cell}
93     static constexpr int kBwInputGateBiasTensor = 29;  // Optional
94     static constexpr int kBwForgetGateBiasTensor = 30;
95     static constexpr int kBwCellGateBiasTensor = 31;
96     static constexpr int kBwOutputGateBiasTensor = 32;
97 
98     // Projection weight tensor of size {n_output, n_cell}
99     static constexpr int kBwProjectionWeightsTensor = 33;  // Optional
100     // Projection bias tensor of size {n_output}
101     static constexpr int kBwProjectionBiasTensor = 34;  // Optional
102 
103     // Stateful input tensors that are variables and will be modified by the Op.
104     // Activation state tensors of size {n_batch, n_output}
105     static constexpr int kFwInputActivationStateTensor = 35;
106     // Cell state tensors of size {n_batch, n_cell}
107     static constexpr int kFwInputCellStateTensor = 36;
108     // Activation state tensors of size {n_batch, n_output}
109     static constexpr int kBwInputActivationStateTensor = 37;
110     // Cell state tensors of size {n_batch, n_cell}
111     static constexpr int kBwInputCellStateTensor = 38;
112 
113     // Used as auxiliary input and weights when stacking for
114     // tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); Used as input
115     // to the backward cell when stacking for tf.nn.static_bidirectional_rnn case
116     // (without cross links).
117     static constexpr int kAuxInputTensor = 39;  // Optional
118     // Forward weights.
119     static constexpr int kFwAuxInputToInputWeightsTensor = 40;   // Optional
120     static constexpr int kFwAuxInputToForgetWeightsTensor = 41;  // Optional
121     static constexpr int kFwAuxInputToCellWeightsTensor = 42;    // Optional
122     static constexpr int kFwAuxInputToOutputWeightsTensor = 43;  // Optional
123     // Backward weights.
124     static constexpr int kBwAuxInputToInputWeightsTensor = 44;   // Optional
125     static constexpr int kBwAuxInputToForgetWeightsTensor = 45;  // Optional
126     static constexpr int kBwAuxInputToCellWeightsTensor = 46;    // Optional
127     static constexpr int kBwAuxInputToOutputWeightsTensor = 47;  // Optional
128 
129     static constexpr int kActivationParam = 48;
130     static constexpr int kCellClipParam = 49;
131     static constexpr int kProjClipParam = 50;
132     static constexpr int kMergeOutputsParam = 51;
133     static constexpr int kTimeMajorParam = 52;
134 
135     // Forward layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
136     static constexpr int kFwInputLayerNormWeightsTensor = 53;   // Optional
137     static constexpr int kFwForgetLayerNormWeightsTensor = 54;  // Optional
138     static constexpr int kFwCellLayerNormWeightsTensor = 55;    // Optional
139     static constexpr int kFwOutputLayerNormWeightsTensor = 56;  // Optional
140     // Backward layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
141     static constexpr int kBwInputLayerNormWeightsTensor = 57;   // Optional
142     static constexpr int kBwForgetLayerNormWeightsTensor = 58;  // Optional
143     static constexpr int kBwCellLayerNormWeightsTensor = 59;    // Optional
144     static constexpr int kBwOutputLayerNormWeightsTensor = 60;  // Optional
145 
146     // Output tensors.
147     static constexpr int kFwOutputTensor = 0;
148     static constexpr int kBwOutputTensor = 1;  // Ignored if merge_outputs is set.
149 
150     static constexpr int kFwOutputActivationStateTensor = 2;
151     static constexpr int kFwOutputCellStateTensor = 3;
152     static constexpr int kBwOutputActivationStateTensor = 4;
153     static constexpr int kBwOutputCellStateTensor = 5;
154 
155    private:
156     LSTMParams params_;
157     Shape fw_scratch_shape_;
158     Shape bw_scratch_shape_;
159 
160     const RunTimeOperandInfo* input_;
161 
162     const RunTimeOperandInfo* aux_input_;
163     const RunTimeOperandInfo* fw_aux_input_to_input_weights_;
164     const RunTimeOperandInfo* fw_aux_input_to_forget_weights_;
165     const RunTimeOperandInfo* fw_aux_input_to_cell_weights_;
166     const RunTimeOperandInfo* fw_aux_input_to_output_weights_;
167     const RunTimeOperandInfo* bw_aux_input_to_input_weights_;
168     const RunTimeOperandInfo* bw_aux_input_to_forget_weights_;
169     const RunTimeOperandInfo* bw_aux_input_to_cell_weights_;
170     const RunTimeOperandInfo* bw_aux_input_to_output_weights_;
171 
172     const RunTimeOperandInfo* fw_input_to_input_weights_;
173     const RunTimeOperandInfo* fw_input_to_forget_weights_;
174     const RunTimeOperandInfo* fw_input_to_cell_weights_;
175     const RunTimeOperandInfo* fw_input_to_output_weights_;
176 
177     const RunTimeOperandInfo* fw_recurrent_to_input_weights_;
178     const RunTimeOperandInfo* fw_recurrent_to_forget_weights_;
179     const RunTimeOperandInfo* fw_recurrent_to_cell_weights_;
180     const RunTimeOperandInfo* fw_recurrent_to_output_weights_;
181 
182     const RunTimeOperandInfo* fw_cell_to_input_weights_;
183     const RunTimeOperandInfo* fw_cell_to_forget_weights_;
184     const RunTimeOperandInfo* fw_cell_to_output_weights_;
185 
186     const RunTimeOperandInfo* fw_input_gate_bias_;
187     const RunTimeOperandInfo* fw_forget_gate_bias_;
188     const RunTimeOperandInfo* fw_cell_bias_;
189     const RunTimeOperandInfo* fw_output_gate_bias_;
190 
191     const RunTimeOperandInfo* fw_projection_weights_;
192     const RunTimeOperandInfo* fw_projection_bias_;
193 
194     const RunTimeOperandInfo* fw_input_layer_norm_weights_;
195     const RunTimeOperandInfo* fw_forget_layer_norm_weights_;
196     const RunTimeOperandInfo* fw_cell_layer_norm_weights_;
197     const RunTimeOperandInfo* fw_output_layer_norm_weights_;
198 
199     const RunTimeOperandInfo* fw_activation_state_;
200     const RunTimeOperandInfo* fw_cell_state_;
201     RunTimeOperandInfo* fw_output_;
202 
203     const RunTimeOperandInfo* bw_input_to_input_weights_;
204     const RunTimeOperandInfo* bw_input_to_forget_weights_;
205     const RunTimeOperandInfo* bw_input_to_cell_weights_;
206     const RunTimeOperandInfo* bw_input_to_output_weights_;
207 
208     const RunTimeOperandInfo* bw_recurrent_to_input_weights_;
209     const RunTimeOperandInfo* bw_recurrent_to_forget_weights_;
210     const RunTimeOperandInfo* bw_recurrent_to_cell_weights_;
211     const RunTimeOperandInfo* bw_recurrent_to_output_weights_;
212 
213     const RunTimeOperandInfo* bw_cell_to_input_weights_;
214     const RunTimeOperandInfo* bw_cell_to_forget_weights_;
215     const RunTimeOperandInfo* bw_cell_to_output_weights_;
216 
217     const RunTimeOperandInfo* bw_input_gate_bias_;
218     const RunTimeOperandInfo* bw_forget_gate_bias_;
219     const RunTimeOperandInfo* bw_cell_bias_;
220     const RunTimeOperandInfo* bw_output_gate_bias_;
221 
222     const RunTimeOperandInfo* bw_projection_weights_;
223     const RunTimeOperandInfo* bw_projection_bias_;
224 
225     const RunTimeOperandInfo* bw_input_layer_norm_weights_;
226     const RunTimeOperandInfo* bw_forget_layer_norm_weights_;
227     const RunTimeOperandInfo* bw_cell_layer_norm_weights_;
228     const RunTimeOperandInfo* bw_output_layer_norm_weights_;
229 
230     const RunTimeOperandInfo* bw_activation_state_;
231     const RunTimeOperandInfo* bw_cell_state_;
232     RunTimeOperandInfo* bw_output_;
233 
234     RunTimeOperandInfo* fw_output_activation_state_;
235     RunTimeOperandInfo* fw_output_cell_state_;
236     RunTimeOperandInfo* bw_output_activation_state_;
237     RunTimeOperandInfo* bw_output_cell_state_;
238 };
239 
240 }  // namespace nn
241 }  // namespace android
242 
243 #endif  // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_TYPES_OPERATIONS_BIDIRECTIONAL_SEQUENCE_LSTM_H
244