• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FRAMEWORKS_ML_NN_BIDIRECTIONAL_SEQUENCE_LSTM_H
18 #define FRAMEWORKS_ML_NN_BIDIRECTIONAL_SEQUENCE_LSTM_H
19 
20 #include "ActivationFunctor.h"
21 #include "HalOperation.h"
22 #include "LSTM.h"
23 #include "OperationsUtils.h"
24 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
25 
26 #include <algorithm>
27 #include <cmath>
28 
29 namespace android {
30 namespace nn {
31 
32 struct RunTimeOperandInfo;
33 
34 class BidirectionalSequenceLSTM {
35    public:
36     BidirectionalSequenceLSTM(const Operation& operation,
37                               std::vector<RunTimeOperandInfo>& operands);
38 
39     bool Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands,
40                  Shape* fwOutputShape, Shape* bwOutputShape);
41     bool Eval();
42 
43     // Input Tensors of size {max_time, n_batch, n_input}
44     static constexpr int kInputTensor = 0;
45 
46     // Forward LSTM cell tensors.
47     // Input weight tensors of size: {n_cell, n_input}
48     static constexpr int kFwInputToInputWeightsTensor = 1;  // Optional
49     static constexpr int kFwInputToForgetWeightsTensor = 2;
50     static constexpr int kFwInputToCellWeightsTensor = 3;
51     static constexpr int kFwInputToOutputWeightsTensor = 4;
52 
53     // Recurrent weight tensors of size {n_cell, n_output}
54     static constexpr int kFwRecurrentToInputWeightsTensor = 5;  // Optional
55     static constexpr int kFwRecurrentToForgetWeightsTensor = 6;
56     static constexpr int kFwRecurrentToCellWeightsTensor = 7;
57     static constexpr int kFwRecurrentToOutputWeightsTensor = 8;
58 
59     // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
60     static constexpr int kFwCellToInputWeightsTensor = 9;    // Optional
61     static constexpr int kFwCellToForgetWeightsTensor = 10;  // Optional
62     static constexpr int kFwCellToOutputWeightsTensor = 11;  // Optional
63 
64     // Gates bias tensors of size {n_cell}
65     static constexpr int kFwInputGateBiasTensor = 12;  // Optional
66     static constexpr int kFwForgetGateBiasTensor = 13;
67     static constexpr int kFwCellGateBiasTensor = 14;
68     static constexpr int kFwOutputGateBiasTensor = 15;
69 
70     // Projection weight tensor of size {n_output, n_cell}
71     static constexpr int kFwProjectionWeightsTensor = 16;  // Optional
72     // Projection bias tensor of size {n_output}
73     static constexpr int kFwProjectionBiasTensor = 17;  // Optional
74 
75     // Backward LSTM cell tensors.
76     // Input weight tensors of size: {n_cell, n_input}
77     static constexpr int kBwInputToInputWeightsTensor = 18;  // Optional
78     static constexpr int kBwInputToForgetWeightsTensor = 19;
79     static constexpr int kBwInputToCellWeightsTensor = 20;
80     static constexpr int kBwInputToOutputWeightsTensor = 21;
81 
82     // Recurrent weight tensors of size {n_cell, n_output}
83     static constexpr int kBwRecurrentToInputWeightsTensor = 22;  // Optional
84     static constexpr int kBwRecurrentToForgetWeightsTensor = 23;
85     static constexpr int kBwRecurrentToCellWeightsTensor = 24;
86     static constexpr int kBwRecurrentToOutputWeightsTensor = 25;
87 
88     // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
89     static constexpr int kBwCellToInputWeightsTensor = 26;   // Optional
90     static constexpr int kBwCellToForgetWeightsTensor = 27;  // Optional
91     static constexpr int kBwCellToOutputWeightsTensor = 28;  // Optional
92 
93     // Gates bias tensors of size {n_cell}
94     static constexpr int kBwInputGateBiasTensor = 29;  // Optional
95     static constexpr int kBwForgetGateBiasTensor = 30;
96     static constexpr int kBwCellGateBiasTensor = 31;
97     static constexpr int kBwOutputGateBiasTensor = 32;
98 
99     // Projection weight tensor of size {n_output, n_cell}
100     static constexpr int kBwProjectionWeightsTensor = 33;  // Optional
101     // Projection bias tensor of size {n_output}
102     static constexpr int kBwProjectionBiasTensor = 34;  // Optional
103 
104     // Stateful input tensors that are variables and will be modified by the Op.
105     // Activation state tensors of size {n_batch, n_output}
106     static constexpr int kFwInputActivationStateTensor = 35;
107     // Cell state tensors of size {n_batch, n_cell}
108     static constexpr int kFwInputCellStateTensor = 36;
109     // Activation state tensors of size {n_batch, n_output}
110     static constexpr int kBwInputActivationStateTensor = 37;
111     // Cell state tensors of size {n_batch, n_cell}
112     static constexpr int kBwInputCellStateTensor = 38;
113 
114     // Used as auxiliary input and weights when stacking for
115     // tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); Used as input
116     // to the backward cell when stacking for tf.nn.static_bidirectional_rnn case
117     // (without cross links).
118     static constexpr int kAuxInputTensor = 39;  // Optional
119     // Forward weights.
120     static constexpr int kFwAuxInputToInputWeightsTensor = 40;   // Optional
121     static constexpr int kFwAuxInputToForgetWeightsTensor = 41;  // Optional
122     static constexpr int kFwAuxInputToCellWeightsTensor = 42;    // Optional
123     static constexpr int kFwAuxInputToOutputWeightsTensor = 43;  // Optional
124     // Backward weights.
125     static constexpr int kBwAuxInputToInputWeightsTensor = 44;   // Optional
126     static constexpr int kBwAuxInputToForgetWeightsTensor = 45;  // Optional
127     static constexpr int kBwAuxInputToCellWeightsTensor = 46;    // Optional
128     static constexpr int kBwAuxInputToOutputWeightsTensor = 47;  // Optional
129 
130     static constexpr int kActivationParam = 48;
131     static constexpr int kCellClipParam = 49;
132     static constexpr int kProjClipParam = 50;
133     static constexpr int kMergeOutputsParam = 51;
134     static constexpr int kTimeMajorParam = 52;
135 
136     // Forward layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
137     static constexpr int kFwInputLayerNormWeightsTensor = 53;   // Optional
138     static constexpr int kFwForgetLayerNormWeightsTensor = 54;  // Optional
139     static constexpr int kFwCellLayerNormWeightsTensor = 55;    // Optional
140     static constexpr int kFwOutputLayerNormWeightsTensor = 56;  // Optional
141     // Backward layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
142     static constexpr int kBwInputLayerNormWeightsTensor = 57;   // Optional
143     static constexpr int kBwForgetLayerNormWeightsTensor = 58;  // Optional
144     static constexpr int kBwCellLayerNormWeightsTensor = 59;    // Optional
145     static constexpr int kBwOutputLayerNormWeightsTensor = 60;  // Optional
146 
147     // Output tensors.
148     static constexpr int kFwOutputTensor = 0;
149     static constexpr int kBwOutputTensor = 1;  // Ignored if merge_outputs is set.
150 
151    private:
152     LSTMParams params_;
153     Shape fw_scratch_shape_;
154     Shape bw_scratch_shape_;
155 
156     const RunTimeOperandInfo* input_;
157 
158     const RunTimeOperandInfo* aux_input_;
159     const RunTimeOperandInfo* fw_aux_input_to_input_weights_;
160     const RunTimeOperandInfo* fw_aux_input_to_forget_weights_;
161     const RunTimeOperandInfo* fw_aux_input_to_cell_weights_;
162     const RunTimeOperandInfo* fw_aux_input_to_output_weights_;
163     const RunTimeOperandInfo* bw_aux_input_to_input_weights_;
164     const RunTimeOperandInfo* bw_aux_input_to_forget_weights_;
165     const RunTimeOperandInfo* bw_aux_input_to_cell_weights_;
166     const RunTimeOperandInfo* bw_aux_input_to_output_weights_;
167 
168     const RunTimeOperandInfo* fw_input_to_input_weights_;
169     const RunTimeOperandInfo* fw_input_to_forget_weights_;
170     const RunTimeOperandInfo* fw_input_to_cell_weights_;
171     const RunTimeOperandInfo* fw_input_to_output_weights_;
172 
173     const RunTimeOperandInfo* fw_recurrent_to_input_weights_;
174     const RunTimeOperandInfo* fw_recurrent_to_forget_weights_;
175     const RunTimeOperandInfo* fw_recurrent_to_cell_weights_;
176     const RunTimeOperandInfo* fw_recurrent_to_output_weights_;
177 
178     const RunTimeOperandInfo* fw_cell_to_input_weights_;
179     const RunTimeOperandInfo* fw_cell_to_forget_weights_;
180     const RunTimeOperandInfo* fw_cell_to_output_weights_;
181 
182     const RunTimeOperandInfo* fw_input_gate_bias_;
183     const RunTimeOperandInfo* fw_forget_gate_bias_;
184     const RunTimeOperandInfo* fw_cell_bias_;
185     const RunTimeOperandInfo* fw_output_gate_bias_;
186 
187     const RunTimeOperandInfo* fw_projection_weights_;
188     const RunTimeOperandInfo* fw_projection_bias_;
189 
190     const RunTimeOperandInfo* fw_input_layer_norm_weights_;
191     const RunTimeOperandInfo* fw_forget_layer_norm_weights_;
192     const RunTimeOperandInfo* fw_cell_layer_norm_weights_;
193     const RunTimeOperandInfo* fw_output_layer_norm_weights_;
194 
195     RunTimeOperandInfo* fw_activation_state_;
196     RunTimeOperandInfo* fw_cell_state_;
197     RunTimeOperandInfo* fw_output_;
198 
199     const RunTimeOperandInfo* bw_input_to_input_weights_;
200     const RunTimeOperandInfo* bw_input_to_forget_weights_;
201     const RunTimeOperandInfo* bw_input_to_cell_weights_;
202     const RunTimeOperandInfo* bw_input_to_output_weights_;
203 
204     const RunTimeOperandInfo* bw_recurrent_to_input_weights_;
205     const RunTimeOperandInfo* bw_recurrent_to_forget_weights_;
206     const RunTimeOperandInfo* bw_recurrent_to_cell_weights_;
207     const RunTimeOperandInfo* bw_recurrent_to_output_weights_;
208 
209     const RunTimeOperandInfo* bw_cell_to_input_weights_;
210     const RunTimeOperandInfo* bw_cell_to_forget_weights_;
211     const RunTimeOperandInfo* bw_cell_to_output_weights_;
212 
213     const RunTimeOperandInfo* bw_input_gate_bias_;
214     const RunTimeOperandInfo* bw_forget_gate_bias_;
215     const RunTimeOperandInfo* bw_cell_bias_;
216     const RunTimeOperandInfo* bw_output_gate_bias_;
217 
218     const RunTimeOperandInfo* bw_projection_weights_;
219     const RunTimeOperandInfo* bw_projection_bias_;
220 
221     const RunTimeOperandInfo* bw_input_layer_norm_weights_;
222     const RunTimeOperandInfo* bw_forget_layer_norm_weights_;
223     const RunTimeOperandInfo* bw_cell_layer_norm_weights_;
224     const RunTimeOperandInfo* bw_output_layer_norm_weights_;
225 
226     RunTimeOperandInfo* bw_activation_state_;
227     RunTimeOperandInfo* bw_cell_state_;
228     RunTimeOperandInfo* bw_output_;
229 };
230 
231 }  // namespace nn
232 }  // namespace android
233 
234 #endif  // FRAMEWORKS_ML_NN_BIDIRECTIONAL_SEQUENCE_LSTM_H
235