1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/nnapi/quant_lstm_sup.h"
16 
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>
18 
19 #include "tensorflow/lite/kernels/kernel_util.h"
20 
21 namespace tflite {
22 namespace delegate {
23 namespace nnapi {
24 
25 // The function extracts a submatrix of the weights at a given row
26 // and column offsets from  a 2D matrix
ExtractQuantLstmWeightsSubmatrix(const TfLiteIntArray * submatrix_dims,const int32_t offset_row,const int32_t offset_column,const TfLiteIntArray * weight_dims,const uint8_t * weights,std::vector<uint8_t> * submatrix)27 void ExtractQuantLstmWeightsSubmatrix(const TfLiteIntArray* submatrix_dims,
28                                       const int32_t offset_row,
29                                       const int32_t offset_column,
30                                       const TfLiteIntArray* weight_dims,
31                                       const uint8_t* weights,
32                                       std::vector<uint8_t>* submatrix) {
33   auto const& submatrix_rows = submatrix_dims->data[0];
34   auto const& submatrix_cols = submatrix_dims->data[1];
35   auto const& weight_cols = weight_dims->data[1];
36 
37   submatrix->resize(NumElements(submatrix_dims));
38 
39   for (uint32_t i = 0, end = submatrix_rows * submatrix_cols; i < end; ++i) {
40     const uint32_t row = i / submatrix_cols;
41     const uint32_t column = i % submatrix_cols;
42     (*submatrix)[i] =
43         weights[(row + offset_row) * weight_cols + column + offset_column];
44   }
45 }
46 
OutputDepth(const TfLiteIntArray * weight_dims)47 inline int OutputDepth(const TfLiteIntArray* weight_dims) {
48   return weight_dims->data[0] / 4;
49 }
50 
InputDepth(const TfLiteIntArray * weight_dims)51 inline int InputDepth(const TfLiteIntArray* weight_dims) {
52   return weight_dims->data[1] - OutputDepth(weight_dims);
53 }
54 
SetWeightSubmatrixDims(const TfLiteIntArray * weight_dims,TfLiteIntArray * recurrent_submatrix_dims,TfLiteIntArray * input_submatrix_dims)55 void SetWeightSubmatrixDims(const TfLiteIntArray* weight_dims,
56                             TfLiteIntArray* recurrent_submatrix_dims,
57                             TfLiteIntArray* input_submatrix_dims) {
58   const auto input_depth = InputDepth(weight_dims);
59   const auto output_depth = OutputDepth(weight_dims);
60 
61   recurrent_submatrix_dims->data[0] = output_depth;
62   recurrent_submatrix_dims->data[1] = output_depth;
63 
64   input_submatrix_dims->data[0] = output_depth;
65   input_submatrix_dims->data[1] = input_depth;
66 }
67 
68 // Doing exactly the opposite work of QuantizedLSTMCell::concatenateWeights
69 // in NNAPI, decomposing the concat_weights tensor data into its 8 components
70 // according to the following diagram
71 //
72 // +-----------------------------------+
73 // | recurrentToInput  | inputToInput  |
74 // |-------------------+---------------|
75 // | recurrentToCell   | inputToCell   |
76 // |-------------------+---------------|
77 // | recurrentToForget | inputToForget |
78 // |-------------------+---------------|
79 // | recurrentToOutput | inputToOutput |
80 // +-----------------------------------+
DecomposeQuantLstmWeightsTensor(const uint8_t * concat_weights,const TfLiteIntArray * weight_dims,std::vector<uint8_t> * recurrent_to_input,std::vector<uint8_t> * input_to_input,std::vector<uint8_t> * recurrent_to_cell,std::vector<uint8_t> * input_to_cell,std::vector<uint8_t> * recurrent_to_forget,std::vector<uint8_t> * input_to_forget,std::vector<uint8_t> * recurrent_to_output,std::vector<uint8_t> * input_to_output)81 void DecomposeQuantLstmWeightsTensor(const uint8_t* concat_weights,
82                                      const TfLiteIntArray* weight_dims,
83                                      std::vector<uint8_t>* recurrent_to_input,
84                                      std::vector<uint8_t>* input_to_input,
85                                      std::vector<uint8_t>* recurrent_to_cell,
86                                      std::vector<uint8_t>* input_to_cell,
87                                      std::vector<uint8_t>* recurrent_to_forget,
88                                      std::vector<uint8_t>* input_to_forget,
89                                      std::vector<uint8_t>* recurrent_to_output,
90                                      std::vector<uint8_t>* input_to_output) {
91   const auto output_depth = OutputDepth(weight_dims);
92 
93   TfLiteIntArray* recurrent_submatrix_dims = TfLiteIntArrayCreate(2);
94   TfLiteIntArray* input_submatrix_dims = TfLiteIntArrayCreate(2);
95   SetWeightSubmatrixDims(weight_dims, recurrent_submatrix_dims,
96                          input_submatrix_dims);
97 
98   ExtractQuantLstmWeightsSubmatrix(recurrent_submatrix_dims, 0 * output_depth,
99                                    0, weight_dims, concat_weights,
100                                    recurrent_to_input);
101   ExtractQuantLstmWeightsSubmatrix(input_submatrix_dims, 0 * output_depth,
102                                    output_depth, weight_dims, concat_weights,
103                                    input_to_input);
104 
105   ExtractQuantLstmWeightsSubmatrix(recurrent_submatrix_dims, 1 * output_depth,
106                                    0, weight_dims, concat_weights,
107                                    recurrent_to_cell);
108   ExtractQuantLstmWeightsSubmatrix(input_submatrix_dims, 1 * output_depth,
109                                    output_depth, weight_dims, concat_weights,
110                                    input_to_cell);
111 
112   ExtractQuantLstmWeightsSubmatrix(recurrent_submatrix_dims, 2 * output_depth,
113                                    0, weight_dims, concat_weights,
114                                    recurrent_to_forget);
115   ExtractQuantLstmWeightsSubmatrix(input_submatrix_dims, 2 * output_depth,
116                                    output_depth, weight_dims, concat_weights,
117                                    input_to_forget);
118 
119   ExtractQuantLstmWeightsSubmatrix(recurrent_submatrix_dims, 3 * output_depth,
120                                    0, weight_dims, concat_weights,
121                                    recurrent_to_output);
122   ExtractQuantLstmWeightsSubmatrix(input_submatrix_dims, 3 * output_depth,
123                                    output_depth, weight_dims, concat_weights,
124                                    input_to_output);
125 
126   TfLiteIntArrayFree(recurrent_submatrix_dims);
127   TfLiteIntArrayFree(input_submatrix_dims);
128 }
129 
// Splits `biases`, four back-to-back segments of `bias_size` values each,
// into the input, cell, forget and output gate bias vectors (in that order).
// Each destination is resized to `bias_size` and then overwritten.
void DecomposeBiasTensor(const int32_t* biases, int bias_size,
                         std::vector<int32_t>* input_bias,
                         std::vector<int32_t>* cell_bias,
                         std::vector<int32_t>* forget_bias,
                         std::vector<int32_t>* output_bias) {
  std::vector<int32_t>* const destinations[] = {input_bias, cell_bias,
                                                forget_bias, output_bias};
  const int32_t* segment_begin = biases;
  for (std::vector<int32_t>* destination : destinations) {
    destination->resize(bias_size);
    const int32_t* segment_end = segment_begin + bias_size;
    std::copy(segment_begin, segment_end, destination->begin());
    segment_begin = segment_end;
  }
}
149 
150 }  // namespace nnapi
151 }  // namespace delegate
152 }  // namespace tflite
153